// oxicuda_levelzero/backend.rs

1//! [`LevelZeroBackend`] — the main entry point for the oxicuda-levelzero crate.
2//!
3//! Implements the [`ComputeBackend`] trait from `oxicuda-backend` using
4//! Intel's Level Zero API for GPU compute on Linux and Windows.
5
6use std::sync::Arc;
7
8use oxicuda_backend::{
9    BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
10};
11
12use crate::{device::LevelZeroDevice, memory::LevelZeroMemoryManager};
13
14// ─── Backend struct ───────────────────────────────────────────────────────────
15
16/// Intel Level Zero GPU compute backend.
17///
18/// On Linux and Windows this selects the first Intel GPU via the Level Zero
19/// loader library (`libze_loader.so` / `ze_loader.dll`) and allocates device
20/// memory through the Level Zero memory model.
21///
22/// On macOS every operation returns [`BackendError::DeviceError`] wrapping
23/// [`crate::error::LevelZeroError::UnsupportedPlatform`].
24///
25/// # Lifecycle
26///
27/// 1. `LevelZeroBackend::new()` — create an uninitialised backend.
28/// 2. `init()` — load the Level Zero driver and select a GPU.
29/// 3. Use `alloc`, `copy_htod`, compute ops, `copy_dtoh`, `free`.
30/// 4. `synchronize()` — wait for all pending GPU work to finish.
#[derive(Debug)]
pub struct LevelZeroBackend {
    /// Selected GPU device; `None` until `init()` succeeds.
    device: Option<Arc<LevelZeroDevice>>,
    /// Device-memory manager built over `device`; `None` until `init()`.
    memory: Option<Arc<LevelZeroMemoryManager>>,
    // True once `init()` has completed successfully; gates every operation
    // via `check_init()`.
    initialized: bool,
}
37
38impl LevelZeroBackend {
39    /// Create a new, uninitialised Level Zero backend.
40    pub fn new() -> Self {
41        Self {
42            device: None,
43            memory: None,
44            initialized: false,
45        }
46    }
47
48    /// Return an error if the backend has not been initialised yet.
49    fn check_init(&self) -> BackendResult<()> {
50        if self.initialized {
51            Ok(())
52        } else {
53            Err(BackendError::NotInitialized)
54        }
55    }
56
57    /// Convenience accessor: get the memory manager or return `NotInitialized`.
58    fn memory(&self) -> BackendResult<&Arc<LevelZeroMemoryManager>> {
59        self.memory.as_ref().ok_or(BackendError::NotInitialized)
60    }
61}
62
63impl Default for LevelZeroBackend {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69// ─── ComputeBackend impl ──────────────────────────────────────────────────────
70
impl ComputeBackend for LevelZeroBackend {
    /// Stable backend identifier reported to callers/registries.
    fn name(&self) -> &str {
        "level-zero"
    }

    /// Load the Level Zero driver, select a device, and build the memory
    /// manager.  Idempotent: calling `init()` on an already-initialised
    /// backend is a no-op returning `Ok(())`.
    fn init(&mut self) -> BackendResult<()> {
        if self.initialized {
            return Ok(());
        }
        match LevelZeroDevice::new() {
            Ok(dev) => {
                let dev = Arc::new(dev);
                tracing::info!("Level Zero backend initialised on: {}", dev.name());
                // The memory manager shares ownership of the device handle.
                let memory = LevelZeroMemoryManager::new(Arc::clone(&dev));
                self.device = Some(dev);
                self.memory = Some(Arc::new(memory));
                self.initialized = true;
                Ok(())
            }
            Err(e) => Err(BackendError::from(e)),
        }
    }

    fn is_initialized(&self) -> bool {
        self.initialized
    }

    // ── Compute operations ────────────────────────────────────────────────────

    /// f32 GEMM dispatched to a SPIR-V kernel; empty problems are a no-op.
    ///
    /// NOTE(review): `trans_a`/`trans_b` and the leading dimensions
    /// (`lda`/`ldb`/`ldc`) are ignored — the dispatch path appears to assume
    /// dense, non-transposed inputs; confirm against the GEMM shader.
    /// `alpha`/`beta` are narrowed from f64 to f32 before dispatch.
    fn gemm(
        &self,
        _trans_a: BackendTranspose,
        _trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        _lda: usize,
        b_ptr: u64,
        _ldb: usize,
        beta: f64,
        c_ptr: u64,
        _ldc: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if m == 0 || n == 0 || k == 0 {
            return Ok(());
        }
        self.dispatch_gemm(m, n, k, alpha as f32, a_ptr, b_ptr, beta as f32, c_ptr)
    }

    /// Batched f32 GEMM.  `stride_*` are forwarded unchanged to the kernel
    /// dispatcher; transpose flags and leading dimensions are ignored as in
    /// [`Self::gemm`].
    #[allow(clippy::too_many_arguments)]
    fn batched_gemm(
        &self,
        _trans_a: BackendTranspose,
        _trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        _lda: usize,
        stride_a: usize,
        b_ptr: u64,
        _ldb: usize,
        stride_b: usize,
        beta: f64,
        c_ptr: u64,
        _ldc: usize,
        stride_c: usize,
        batch_count: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if batch_count == 0 || m == 0 || n == 0 || k == 0 {
            return Ok(());
        }
        self.dispatch_batched_gemm(
            m,
            n,
            k,
            alpha as f32,
            a_ptr,
            b_ptr,
            beta as f32,
            c_ptr,
            batch_count,
            stride_a,
            stride_b,
            stride_c,
        )
    }

    /// 2-D convolution forward pass (f32, NCHW), computed on the CPU:
    /// input and filter are copied device→host, convolved here, and the
    /// result is copied host→device into `output_ptr`.
    ///
    /// NOTE(review): `filter_shape[1]` (filter input-channels) and
    /// `output_shape[0..2]` (batch / out-channels) are never cross-checked
    /// against `input_shape`/`filter_shape[0]`; a mismatch is not detected
    /// and would silently misread buffers — consider validating.
    fn conv2d_forward(
        &self,
        input_ptr: u64,
        input_shape: &[usize],
        filter_ptr: u64,
        filter_shape: &[usize],
        output_ptr: u64,
        output_shape: &[usize],
        stride: &[usize],
        padding: &[usize],
    ) -> BackendResult<()> {
        self.check_init()?;

        if input_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "input_shape must have 4 elements (NCHW)".into(),
            ));
        }
        if filter_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "filter_shape must have 4 elements (KCFHFW)".into(),
            ));
        }
        if output_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "output_shape must have 4 elements (NKOhOw)".into(),
            ));
        }
        if stride.len() != 2 {
            return Err(BackendError::InvalidArgument(
                "stride must have 2 elements [sh, sw]".into(),
            ));
        }
        if padding.len() != 2 {
            return Err(BackendError::InvalidArgument(
                "padding must have 2 elements [ph, pw]".into(),
            ));
        }

        // Decode NCHW / K,C,Fh,Fw / N,K,Oh,Ow dimensions.
        let n = input_shape[0];
        let c_in = input_shape[1];
        let h_in = input_shape[2];
        let w_in = input_shape[3];
        let k_out = filter_shape[0];
        let fh = filter_shape[2];
        let fw = filter_shape[3];
        let o_h = output_shape[2];
        let o_w = output_shape[3];
        let stride_h = stride[0];
        let stride_w = stride[1];
        let pad_h = padding[0];
        let pad_w = padding[1];

        // CPU fallback: copy input + filter from device
        // Element counts; buffers are f32, hence the `* 4` byte sizing below.
        let in_len = n * c_in * h_in * w_in;
        let flt_len = k_out * c_in * fh * fw;
        let out_len = n * k_out * o_h * o_w;

        let mut in_bytes = vec![0u8; in_len * 4];
        self.copy_dtoh(&mut in_bytes, input_ptr)?;
        let inp: Vec<f32> = in_bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        let mut flt_bytes = vec![0u8; flt_len * 4];
        self.copy_dtoh(&mut flt_bytes, filter_ptr)?;
        let flt: Vec<f32> = flt_bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        // NCHW convolution
        let mut out = vec![0.0f32; out_len];
        for b_idx in 0..n {
            for kf in 0..k_out {
                for oy in 0..o_h {
                    for ox in 0..o_w {
                        let mut acc = 0.0f32;
                        for ci in 0..c_in {
                            for fy in 0..fh {
                                for fx in 0..fw {
                                    // Map output coords back to (possibly
                                    // padded) input coords; signed so that
                                    // positions inside the padding border go
                                    // negative and are skipped below.
                                    let iy = (oy * stride_h + fy) as isize - pad_h as isize;
                                    let ix = (ox * stride_w + fx) as isize - pad_w as isize;
                                    if iy >= 0
                                        && (iy as usize) < h_in
                                        && ix >= 0
                                        && (ix as usize) < w_in
                                    {
                                        let iy = iy as usize;
                                        let ix = ix as usize;
                                        acc += inp[((b_idx * c_in + ci) * h_in + iy) * w_in + ix]
                                            * flt[((kf * c_in + ci) * fh + fy) * fw + fx];
                                    }
                                }
                            }
                        }
                        out[((b_idx * k_out + kf) * o_h + oy) * o_w + ox] = acc;
                    }
                }
            }
        }

        let out_bytes: Vec<u8> = out.iter().flat_map(|f| f.to_ne_bytes()).collect();
        self.copy_htod(output_ptr, &out_bytes)
    }

    /// Scaled dot-product attention (f32), computed on the CPU via
    /// device→host copies.  Tensors are flattened as
    /// `[(batch * heads), seq, head_dim]` — see the offset arithmetic below.
    /// The output buffer at `o_ptr` is fully overwritten.
    ///
    /// Softmax is computed in two passes over K for numerical stability
    /// (subtracting the row max), which recomputes each Q·K dot product
    /// once per pass.
    fn attention(
        &self,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        o_ptr: u64,
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_kv: usize,
        head_dim: usize,
        scale: f64,
        causal: bool,
    ) -> BackendResult<()> {
        self.check_init()?;

        if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
            return Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into(),
            ));
        }
        if scale <= 0.0 || !scale.is_finite() {
            return Err(BackendError::InvalidArgument(format!(
                "scale must be a positive finite number, got {scale}"
            )));
        }

        let batch_heads = batch * heads;
        let scale_f32 = scale as f32;

        // CPU fallback: copy Q, K, V from device
        let q_len = batch_heads * seq_q * head_dim;
        let kv_len = batch_heads * seq_kv * head_dim;

        let mut q_bytes = vec![0u8; q_len * 4];
        self.copy_dtoh(&mut q_bytes, q_ptr)?;
        let q: Vec<f32> = q_bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        let mut k_bytes = vec![0u8; kv_len * 4];
        self.copy_dtoh(&mut k_bytes, k_ptr)?;
        let k: Vec<f32> = k_bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        let mut v_bytes = vec![0u8; kv_len * 4];
        self.copy_dtoh(&mut v_bytes, v_ptr)?;
        let v: Vec<f32> = v_bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        // Numerically-stable scaled dot-product attention
        let mut output = vec![0.0f32; q_len];

        for bh in 0..batch_heads {
            for sq in 0..seq_q {
                let q_off = (bh * seq_q + sq) * head_dim;
                let o_off = q_off;

                // Pass 1: find max score
                let mut max_score = f32::NEG_INFINITY;
                for sk in 0..seq_kv {
                    // Causal mask: query position sq may not attend to
                    // later key positions.
                    if causal && sk > sq {
                        continue;
                    }
                    let k_off = (bh * seq_kv + sk) * head_dim;
                    let mut dot = 0.0f32;
                    for d in 0..head_dim {
                        dot += q[q_off + d] * k[k_off + d];
                    }
                    let score = dot * scale_f32;
                    if score > max_score {
                        max_score = score;
                    }
                }

                // Guard for the all-masked case (sk = 0 is always visible,
                // so this only triggers on degenerate inputs).
                if max_score == f32::NEG_INFINITY {
                    max_score = 0.0;
                }

                // Pass 2: accumulate exp-weighted V
                let mut sum_exp = 0.0f32;
                for sk in 0..seq_kv {
                    if causal && sk > sq {
                        continue;
                    }
                    let k_off = (bh * seq_kv + sk) * head_dim;
                    let v_off = (bh * seq_kv + sk) * head_dim;
                    let mut dot = 0.0f32;
                    for d in 0..head_dim {
                        dot += q[q_off + d] * k[k_off + d];
                    }
                    let w = (dot * scale_f32 - max_score).exp();
                    sum_exp += w;
                    for d in 0..head_dim {
                        output[o_off + d] += w * v[v_off + d];
                    }
                }

                // Normalize
                if sum_exp > 0.0 {
                    for d in 0..head_dim {
                        output[o_off + d] /= sum_exp;
                    }
                }
            }
        }

        let o_bytes: Vec<u8> = output.iter().flat_map(|f| f.to_ne_bytes()).collect();
        self.copy_htod(o_ptr, &o_bytes)
    }

    /// Reduce `input_ptr` along `axis` of `shape`, writing to `output_ptr`.
    /// Validates the axis here, then defers to the SPIR-V dispatch path.
    fn reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()> {
        self.check_init()?;

        if shape.is_empty() {
            return Err(BackendError::InvalidArgument(
                "shape must not be empty".into(),
            ));
        }
        if axis >= shape.len() {
            return Err(BackendError::InvalidArgument(format!(
                "axis {axis} is out of bounds for shape of length {}",
                shape.len()
            )));
        }

        self.dispatch_reduce(op, input_ptr, output_ptr, shape, axis)
    }

    /// Element-wise unary op over `n` elements; `n == 0` is a no-op.
    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }
        self.dispatch_unary(op, input_ptr, output_ptr, n)
    }

    /// Element-wise binary op over `n` elements; `n == 0` is a no-op.
    fn binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }
        self.dispatch_binary(op, a_ptr, b_ptr, output_ptr, n)
    }

    // ── Synchronisation ───────────────────────────────────────────────────────

    /// Block until all work queued on the device command queue completes.
    /// On targets without Level Zero support this compiles to `Ok(())`
    /// (though `check_init` will normally have failed first there).
    fn synchronize(&self) -> BackendResult<()> {
        self.check_init()?;

        #[cfg(any(target_os = "linux", target_os = "windows"))]
        {
            if let Some(dev) = &self.device {
                let api = &dev.api;
                let queue = dev.queue;
                // SAFETY: `queue` is a valid command queue handle and the
                // backend is initialized.  u64::MAX means "wait indefinitely".
                let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
                if rc != 0 {
                    return Err(BackendError::DeviceError(format!(
                        "zeCommandQueueSynchronize failed: 0x{rc:08x}"
                    )));
                }
            }
        }

        Ok(())
    }

    // ── Memory management ─────────────────────────────────────────────────────

    /// Allocate `bytes` of device memory; zero-byte requests are rejected
    /// rather than passed to the driver.  Returns an opaque buffer handle.
    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
        self.check_init()?;
        if bytes == 0 {
            return Err(BackendError::InvalidArgument(
                "cannot allocate 0 bytes".into(),
            ));
        }
        self.memory()?.alloc(bytes).map_err(BackendError::from)
    }

    /// Release a buffer previously returned by [`Self::alloc`].
    fn free(&self, ptr: u64) -> BackendResult<()> {
        self.check_init()?;
        self.memory()?.free(ptr).map_err(BackendError::from)
    }

    /// Host→device copy; an empty `src` is a no-op.
    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
        self.check_init()?;
        if src.is_empty() {
            return Ok(());
        }
        self.memory()?
            .copy_to_device(dst, src)
            .map_err(BackendError::from)
    }

    /// Device→host copy filling all of `dst`; an empty `dst` is a no-op.
    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
        self.check_init()?;
        if dst.is_empty() {
            return Ok(());
        }
        self.memory()?
            .copy_from_device(dst, src)
            .map_err(BackendError::from)
    }
}
496
497// ─── Dispatch helpers ────────────────────────────────────────────────────────
498
/// Workgroup size matching the SPIR-V LocalSize declaration.
const WORKGROUP_SIZE: u32 = crate::spirv::WORKGROUP_SIZE;

/// A kernel argument value for the Level Zero dispatch pipeline.
///
/// On targets without Level Zero (anything but Linux/Windows) the dispatch
/// helpers never construct these, hence the `dead_code` allowance.
#[cfg_attr(not(any(target_os = "linux", target_os = "windows")), allow(dead_code))]
enum KernelArg {
    /// Buffer handle — resolved to a raw device pointer at dispatch time.
    Buffer(u64),
    /// 32-bit unsigned integer scalar.
    U32(u32),
    /// 32-bit float scalar.
    F32(f32),
}
512
513impl LevelZeroBackend {
    /// Dispatch a SPIR-V compute kernel via Level Zero.
    ///
    /// 1. Build a Level Zero module from `spv_words`.
    /// 2. Create a kernel named `"main"` from the module.
    /// 3. Set group size and kernel arguments.
    /// 4. Append a launch to a command list, execute, and wait.
    /// 5. Clean up all Level Zero objects.
    ///
    /// `workgroups` is the X-dimension group count (Y and Z are 1).
    ///
    /// Cleanup is manual on every error path (no RAII guard) — when editing,
    /// keep each `return Err(...)` preceded by destroy calls for every object
    /// created so far, in reverse creation order.
    fn run_kernel(
        &self,
        spv_words: &[u32],
        args: &[KernelArg],
        workgroups: u32,
    ) -> BackendResult<()> {
        #[cfg(any(target_os = "linux", target_os = "windows"))]
        {
            use std::ffi::c_void;

            use crate::device::{
                ZE_MODULE_FORMAT_IL_SPIRV, ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
                ZE_STRUCTURE_TYPE_KERNEL_DESC, ZE_STRUCTURE_TYPE_MODULE_DESC, ZeCommandListDesc,
                ZeGroupCount, ZeKernelDesc, ZeKernelHandle, ZeModuleDesc, ZeModuleHandle,
            };

            let device = self.device.as_ref().ok_or(BackendError::NotInitialized)?;
            let memory = self.memory()?;
            let api = &device.api;
            let context = device.context;
            let dev_handle = device.device;
            let queue = device.queue;

            // ── 1. SPIR-V words → bytes ──
            let spv_bytes: Vec<u8> = spv_words.iter().flat_map(|w| w.to_ne_bytes()).collect();

            // ── 2. Create module ──
            // `spv_bytes` must outlive the zeModuleCreate call; it does, as it
            // lives to the end of this scope.
            let module_desc = ZeModuleDesc {
                stype: ZE_STRUCTURE_TYPE_MODULE_DESC,
                p_next: std::ptr::null(),
                format: ZE_MODULE_FORMAT_IL_SPIRV,
                input_size: spv_bytes.len(),
                p_input_module: spv_bytes.as_ptr(),
                p_build_flags: std::ptr::null(),
                p_constants: std::ptr::null(),
            };
            let mut module: ZeModuleHandle = std::ptr::null_mut();
            let rc = unsafe {
                (api.ze_module_create)(
                    context,
                    dev_handle,
                    &module_desc,
                    &mut module as *mut ZeModuleHandle,
                    std::ptr::null_mut(),
                )
            };
            if rc != 0 {
                return Err(BackendError::DeviceError(format!(
                    "zeModuleCreate failed: 0x{rc:08x}"
                )));
            }

            // ── 3. Create kernel ──
            // NOTE(review): the entry point is hard-coded as "main" — all
            // shaders produced by `crate::spirv` are assumed to use this name.
            let kernel_name = b"main\0";
            let kernel_desc = ZeKernelDesc {
                stype: ZE_STRUCTURE_TYPE_KERNEL_DESC,
                p_next: std::ptr::null(),
                flags: 0,
                p_kernel_name: kernel_name.as_ptr(),
            };
            let mut kernel: ZeKernelHandle = std::ptr::null_mut();
            let rc = unsafe {
                (api.ze_kernel_create)(module, &kernel_desc, &mut kernel as *mut ZeKernelHandle)
            };
            if rc != 0 {
                unsafe { (api.ze_module_destroy)(module) };
                return Err(BackendError::DeviceError(format!(
                    "zeKernelCreate failed: 0x{rc:08x}"
                )));
            }

            // ── 4. Set group size ──
            // 1-D local size; must match the shader's LocalSize declaration.
            let rc = unsafe { (api.ze_kernel_set_group_size)(kernel, WORKGROUP_SIZE, 1, 1) };
            if rc != 0 {
                unsafe {
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeKernelSetGroupSize failed: 0x{rc:08x}"
                )));
            }

            // ── 5. Set kernel arguments ──
            // Argument index = position in `args`; Buffer handles are resolved
            // to raw device pointers via the memory manager first.
            for (idx, arg) in args.iter().enumerate() {
                let rc = match arg {
                    KernelArg::Buffer(handle) => {
                        // On lookup failure, tear down before propagating with `?`.
                        let dev_ptr = memory.device_ptr(*handle).map_err(|e| {
                            unsafe {
                                (api.ze_kernel_destroy)(kernel);
                                (api.ze_module_destroy)(module);
                            }
                            BackendError::from(e)
                        })?;
                        unsafe {
                            (api.ze_kernel_set_argument_value)(
                                kernel,
                                idx as u32,
                                std::mem::size_of::<*mut c_void>(),
                                &dev_ptr as *const *mut c_void as *const c_void,
                            )
                        }
                    }
                    KernelArg::U32(val) => unsafe {
                        (api.ze_kernel_set_argument_value)(
                            kernel,
                            idx as u32,
                            std::mem::size_of::<u32>(),
                            val as *const u32 as *const c_void,
                        )
                    },
                    KernelArg::F32(val) => unsafe {
                        (api.ze_kernel_set_argument_value)(
                            kernel,
                            idx as u32,
                            std::mem::size_of::<f32>(),
                            val as *const f32 as *const c_void,
                        )
                    },
                };
                if rc != 0 {
                    unsafe {
                        (api.ze_kernel_destroy)(kernel);
                        (api.ze_module_destroy)(module);
                    }
                    return Err(BackendError::DeviceError(format!(
                        "zeKernelSetArgumentValue(arg={idx}) failed: 0x{rc:08x}"
                    )));
                }
            }

            // ── 6. Create command list ──
            let list_desc = ZeCommandListDesc {
                stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
                p_next: std::ptr::null(),
                command_queue_group_ordinal: 0,
                flags: 0,
            };
            let mut list = std::ptr::null_mut();
            let rc =
                unsafe { (api.ze_command_list_create)(context, dev_handle, &list_desc, &mut list) };
            if rc != 0 {
                unsafe {
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandListCreate failed: 0x{rc:08x}"
                )));
            }

            // ── 7. Append launch kernel ──
            let group_count = ZeGroupCount {
                group_count_x: workgroups,
                group_count_y: 1,
                group_count_z: 1,
            };
            let rc = unsafe {
                (api.ze_command_list_append_launch_kernel)(
                    list,
                    kernel,
                    &group_count,
                    0,
                    0,
                    std::ptr::null(),
                )
            };
            if rc != 0 {
                unsafe {
                    (api.ze_command_list_destroy)(list);
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandListAppendLaunchKernel failed: 0x{rc:08x}"
                )));
            }

            // ── 8. Close + execute + wait ──
            let rc = unsafe { (api.ze_command_list_close)(list) };
            if rc != 0 {
                unsafe {
                    (api.ze_command_list_destroy)(list);
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandListClose failed: 0x{rc:08x}"
                )));
            }

            let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
            if rc != 0 {
                unsafe {
                    (api.ze_command_list_destroy)(list);
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
                )));
            }

            // Blocking wait; u64::MAX = infinite timeout (matches
            // `synchronize`).
            let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
            if rc != 0 {
                unsafe {
                    (api.ze_command_list_destroy)(list);
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandQueueSynchronize failed: 0x{rc:08x}"
                )));
            }

            // ── 9. Clean up ──
            unsafe {
                (api.ze_command_list_destroy)(list);
                (api.ze_kernel_destroy)(kernel);
                (api.ze_module_destroy)(module);
            }

            Ok(())
        }

        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
        {
            let _ = (spv_words, args, workgroups);
            Err(BackendError::DeviceError(
                "Level Zero requires Linux or Windows".into(),
            ))
        }
    }
754
755    fn dispatch_unary(
756        &self,
757        op: UnaryOp,
758        input_ptr: u64,
759        output_ptr: u64,
760        n: usize,
761    ) -> BackendResult<()> {
762        let spv = crate::spirv::unary_compute_shader(op);
763        let args = [
764            KernelArg::Buffer(input_ptr),
765            KernelArg::Buffer(output_ptr),
766            KernelArg::U32(n as u32),
767        ];
768        self.run_kernel(&spv, &args, (n as u32).div_ceil(WORKGROUP_SIZE))
769    }
770
771    fn dispatch_binary(
772        &self,
773        op: BinaryOp,
774        a_ptr: u64,
775        b_ptr: u64,
776        output_ptr: u64,
777        n: usize,
778    ) -> BackendResult<()> {
779        let spv = crate::spirv::binary_compute_shader(op);
780        let args = [
781            KernelArg::Buffer(a_ptr),
782            KernelArg::Buffer(b_ptr),
783            KernelArg::Buffer(output_ptr),
784            KernelArg::U32(n as u32),
785        ];
786        self.run_kernel(&spv, &args, (n as u32).div_ceil(WORKGROUP_SIZE))
787    }
788
789    fn dispatch_reduce(
790        &self,
791        op: ReduceOp,
792        input_ptr: u64,
793        output_ptr: u64,
794        shape: &[usize],
795        axis: usize,
796    ) -> BackendResult<()> {
797        let outer_size: usize = shape[..axis].iter().product::<usize>().max(1);
798        let reduce_size = shape[axis];
799        let inner_size: usize = shape[axis + 1..].iter().product::<usize>().max(1);
800
801        let spv = crate::spirv::reduce_compute_shader(op);
802        let total_output = (outer_size * inner_size) as u32;
803        let args = [
804            KernelArg::Buffer(input_ptr),
805            KernelArg::Buffer(output_ptr),
806            KernelArg::U32(outer_size as u32),
807            KernelArg::U32(reduce_size as u32),
808            KernelArg::U32(inner_size as u32),
809        ];
810        self.run_kernel(&spv, &args, total_output.div_ceil(WORKGROUP_SIZE))
811    }
812
813    /// Dispatch a SPIR-V compute kernel with a 3D work group count.
814    ///
815    /// Like [`run_kernel`](Self::run_kernel) but supports 3D dispatch via
816    /// `(group_count_x, group_count_y, group_count_z)`.
817    fn run_kernel_3d(
818        &self,
819        spv_words: &[u32],
820        args: &[KernelArg],
821        workgroups_x: u32,
822        workgroups_y: u32,
823        workgroups_z: u32,
824    ) -> BackendResult<()> {
825        #[cfg(any(target_os = "linux", target_os = "windows"))]
826        {
827            use std::ffi::c_void;
828
829            use crate::device::{
830                ZE_MODULE_FORMAT_IL_SPIRV, ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
831                ZE_STRUCTURE_TYPE_KERNEL_DESC, ZE_STRUCTURE_TYPE_MODULE_DESC, ZeCommandListDesc,
832                ZeGroupCount, ZeKernelDesc, ZeKernelHandle, ZeModuleDesc, ZeModuleHandle,
833            };
834
835            let device = self.device.as_ref().ok_or(BackendError::NotInitialized)?;
836            let memory = self.memory()?;
837            let api = &device.api;
838            let context = device.context;
839            let dev_handle = device.device;
840            let queue = device.queue;
841
842            // ── 1. SPIR-V words → bytes ──
843            let spv_bytes: Vec<u8> = spv_words.iter().flat_map(|w| w.to_ne_bytes()).collect();
844
845            // ── 2. Create module ──
846            let module_desc = ZeModuleDesc {
847                stype: ZE_STRUCTURE_TYPE_MODULE_DESC,
848                p_next: std::ptr::null(),
849                format: ZE_MODULE_FORMAT_IL_SPIRV,
850                input_size: spv_bytes.len(),
851                p_input_module: spv_bytes.as_ptr(),
852                p_build_flags: std::ptr::null(),
853                p_constants: std::ptr::null(),
854            };
855            let mut module: ZeModuleHandle = std::ptr::null_mut();
856            let rc = unsafe {
857                (api.ze_module_create)(
858                    context,
859                    dev_handle,
860                    &module_desc,
861                    &mut module as *mut ZeModuleHandle,
862                    std::ptr::null_mut(),
863                )
864            };
865            if rc != 0 {
866                return Err(BackendError::DeviceError(format!(
867                    "zeModuleCreate failed: 0x{rc:08x}"
868                )));
869            }
870
871            // ── 3. Create kernel ──
872            let kernel_name = b"main\0";
873            let kernel_desc = ZeKernelDesc {
874                stype: ZE_STRUCTURE_TYPE_KERNEL_DESC,
875                p_next: std::ptr::null(),
876                flags: 0,
877                p_kernel_name: kernel_name.as_ptr(),
878            };
879            let mut kernel: ZeKernelHandle = std::ptr::null_mut();
880            let rc = unsafe {
881                (api.ze_kernel_create)(module, &kernel_desc, &mut kernel as *mut ZeKernelHandle)
882            };
883            if rc != 0 {
884                unsafe { (api.ze_module_destroy)(module) };
885                return Err(BackendError::DeviceError(format!(
886                    "zeKernelCreate failed: 0x{rc:08x}"
887                )));
888            }
889
890            // ── 4. Set group size ──
891            let rc = unsafe { (api.ze_kernel_set_group_size)(kernel, WORKGROUP_SIZE, 1, 1) };
892            if rc != 0 {
893                unsafe {
894                    (api.ze_kernel_destroy)(kernel);
895                    (api.ze_module_destroy)(module);
896                }
897                return Err(BackendError::DeviceError(format!(
898                    "zeKernelSetGroupSize failed: 0x{rc:08x}"
899                )));
900            }
901
902            // ── 5. Set kernel arguments ──
903            for (idx, arg) in args.iter().enumerate() {
904                let rc = match arg {
905                    KernelArg::Buffer(handle) => {
906                        let dev_ptr = memory.device_ptr(*handle).map_err(|e| {
907                            unsafe {
908                                (api.ze_kernel_destroy)(kernel);
909                                (api.ze_module_destroy)(module);
910                            }
911                            BackendError::from(e)
912                        })?;
913                        unsafe {
914                            (api.ze_kernel_set_argument_value)(
915                                kernel,
916                                idx as u32,
917                                std::mem::size_of::<*mut c_void>(),
918                                &dev_ptr as *const *mut c_void as *const c_void,
919                            )
920                        }
921                    }
922                    KernelArg::U32(val) => unsafe {
923                        (api.ze_kernel_set_argument_value)(
924                            kernel,
925                            idx as u32,
926                            std::mem::size_of::<u32>(),
927                            val as *const u32 as *const c_void,
928                        )
929                    },
930                    KernelArg::F32(val) => unsafe {
931                        (api.ze_kernel_set_argument_value)(
932                            kernel,
933                            idx as u32,
934                            std::mem::size_of::<f32>(),
935                            val as *const f32 as *const c_void,
936                        )
937                    },
938                };
939                if rc != 0 {
940                    unsafe {
941                        (api.ze_kernel_destroy)(kernel);
942                        (api.ze_module_destroy)(module);
943                    }
944                    return Err(BackendError::DeviceError(format!(
945                        "zeKernelSetArgumentValue(arg={idx}) failed: 0x{rc:08x}"
946                    )));
947                }
948            }
949
950            // ── 6. Create command list ──
951            let list_desc = ZeCommandListDesc {
952                stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
953                p_next: std::ptr::null(),
954                command_queue_group_ordinal: 0,
955                flags: 0,
956            };
957            let mut list = std::ptr::null_mut();
958            let rc =
959                unsafe { (api.ze_command_list_create)(context, dev_handle, &list_desc, &mut list) };
960            if rc != 0 {
961                unsafe {
962                    (api.ze_kernel_destroy)(kernel);
963                    (api.ze_module_destroy)(module);
964                }
965                return Err(BackendError::DeviceError(format!(
966                    "zeCommandListCreate failed: 0x{rc:08x}"
967                )));
968            }
969
970            // ── 7. Append launch kernel (3D) ──
971            let group_count = ZeGroupCount {
972                group_count_x: workgroups_x,
973                group_count_y: workgroups_y,
974                group_count_z: workgroups_z,
975            };
976            let rc = unsafe {
977                (api.ze_command_list_append_launch_kernel)(
978                    list,
979                    kernel,
980                    &group_count,
981                    0,
982                    0,
983                    std::ptr::null(),
984                )
985            };
986            if rc != 0 {
987                unsafe {
988                    (api.ze_command_list_destroy)(list);
989                    (api.ze_kernel_destroy)(kernel);
990                    (api.ze_module_destroy)(module);
991                }
992                return Err(BackendError::DeviceError(format!(
993                    "zeCommandListAppendLaunchKernel failed: 0x{rc:08x}"
994                )));
995            }
996
997            // ── 8. Close + execute + wait ──
998            let rc = unsafe { (api.ze_command_list_close)(list) };
999            if rc != 0 {
1000                unsafe {
1001                    (api.ze_command_list_destroy)(list);
1002                    (api.ze_kernel_destroy)(kernel);
1003                    (api.ze_module_destroy)(module);
1004                }
1005                return Err(BackendError::DeviceError(format!(
1006                    "zeCommandListClose failed: 0x{rc:08x}"
1007                )));
1008            }
1009
1010            let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
1011            if rc != 0 {
1012                unsafe {
1013                    (api.ze_command_list_destroy)(list);
1014                    (api.ze_kernel_destroy)(kernel);
1015                    (api.ze_module_destroy)(module);
1016                }
1017                return Err(BackendError::DeviceError(format!(
1018                    "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
1019                )));
1020            }
1021
1022            let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
1023            if rc != 0 {
1024                unsafe {
1025                    (api.ze_command_list_destroy)(list);
1026                    (api.ze_kernel_destroy)(kernel);
1027                    (api.ze_module_destroy)(module);
1028                }
1029                return Err(BackendError::DeviceError(format!(
1030                    "zeCommandQueueSynchronize failed: 0x{rc:08x}"
1031                )));
1032            }
1033
1034            // ── 9. Clean up ──
1035            unsafe {
1036                (api.ze_command_list_destroy)(list);
1037                (api.ze_kernel_destroy)(kernel);
1038                (api.ze_module_destroy)(module);
1039            }
1040
1041            Ok(())
1042        }
1043
1044        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
1045        {
1046            let _ = (spv_words, args, workgroups_x, workgroups_y, workgroups_z);
1047            Err(BackendError::DeviceError(
1048                "Level Zero requires Linux or Windows".into(),
1049            ))
1050        }
1051    }
1052
1053    #[allow(clippy::too_many_arguments)]
1054    fn dispatch_batched_gemm(
1055        &self,
1056        m: usize,
1057        n: usize,
1058        k: usize,
1059        alpha: f32,
1060        a_ptr: u64,
1061        b_ptr: u64,
1062        beta: f32,
1063        c_ptr: u64,
1064        batch_count: usize,
1065        stride_a: usize,
1066        stride_b: usize,
1067        stride_c: usize,
1068    ) -> BackendResult<()> {
1069        let spv = crate::spirv::batched_gemm_compute_shader();
1070        let total_per_batch = (m * n) as u32;
1071        let workgroups_x = total_per_batch.div_ceil(WORKGROUP_SIZE);
1072        let args = [
1073            KernelArg::Buffer(a_ptr),
1074            KernelArg::Buffer(b_ptr),
1075            KernelArg::Buffer(c_ptr),
1076            KernelArg::U32(m as u32),
1077            KernelArg::U32(n as u32),
1078            KernelArg::U32(k as u32),
1079            KernelArg::F32(alpha),
1080            KernelArg::F32(beta),
1081            KernelArg::U32(batch_count as u32),
1082            KernelArg::U32(stride_a as u32),
1083            KernelArg::U32(stride_b as u32),
1084            KernelArg::U32(stride_c as u32),
1085        ];
1086        self.run_kernel_3d(&spv, &args, workgroups_x, 1, batch_count as u32)
1087    }
1088
1089    #[allow(clippy::too_many_arguments)]
1090    fn dispatch_gemm(
1091        &self,
1092        m: usize,
1093        n: usize,
1094        k: usize,
1095        alpha: f32,
1096        a_ptr: u64,
1097        b_ptr: u64,
1098        beta: f32,
1099        c_ptr: u64,
1100    ) -> BackendResult<()> {
1101        let spv = crate::spirv::gemm_compute_shader();
1102        let total = (m * n) as u32;
1103        let args = [
1104            KernelArg::Buffer(a_ptr),
1105            KernelArg::Buffer(b_ptr),
1106            KernelArg::Buffer(c_ptr),
1107            KernelArg::U32(m as u32),
1108            KernelArg::U32(n as u32),
1109            KernelArg::U32(k as u32),
1110            KernelArg::F32(alpha),
1111            KernelArg::F32(beta),
1112        ];
1113        self.run_kernel(&spv, &args, total.div_ceil(WORKGROUP_SIZE))
1114    }
1115}
1116
1117// ─── Tests ───────────────────────────────────────────────────────────────────
1118
1119#[cfg(test)]
1120mod tests {
1121    use super::*;
1122    use oxicuda_backend::{BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp};
1123
1124    // ── Construction ──────────────────────────────────────────────────────────
1125
1126    #[test]
1127    fn level_zero_backend_new_uninitialized() {
1128        let b = LevelZeroBackend::new();
1129        assert!(!b.is_initialized());
1130    }
1131
1132    #[test]
1133    fn level_zero_backend_name() {
1134        let b = LevelZeroBackend::new();
1135        assert_eq!(b.name(), "level-zero");
1136    }
1137
1138    #[test]
1139    fn level_zero_backend_default() {
1140        let b = LevelZeroBackend::default();
1141        assert!(!b.is_initialized());
1142        assert_eq!(b.name(), "level-zero");
1143    }
1144
1145    #[test]
1146    fn backend_debug_impl() {
1147        let b = LevelZeroBackend::new();
1148        let s = format!("{b:?}");
1149        assert!(s.contains("LevelZeroBackend"));
1150    }
1151
1152    // ── Object-safety smoke test ──────────────────────────────────────────────
1153
1154    #[test]
1155    fn backend_object_safe() {
1156        let b: Box<dyn ComputeBackend> = Box::new(LevelZeroBackend::new());
1157        assert_eq!(b.name(), "level-zero");
1158    }
1159
1160    // ── Not-initialized guards ────────────────────────────────────────────────
1161
1162    #[test]
1163    fn backend_not_initialized_gemm() {
1164        let b = LevelZeroBackend::new();
1165        let result = b.gemm(
1166            BackendTranspose::NoTrans,
1167            BackendTranspose::NoTrans,
1168            4,
1169            4,
1170            4,
1171            1.0,
1172            0,
1173            4,
1174            0,
1175            4,
1176            0.0,
1177            0,
1178            4,
1179        );
1180        assert_eq!(result, Err(BackendError::NotInitialized));
1181    }
1182
1183    #[test]
1184    fn backend_not_initialized_batched_gemm() {
1185        let b = LevelZeroBackend::new();
1186        let result = b.batched_gemm(
1187            BackendTranspose::NoTrans,
1188            BackendTranspose::NoTrans,
1189            4,
1190            4,
1191            4,
1192            1.0,
1193            0,
1194            4,
1195            16,
1196            0,
1197            4,
1198            16,
1199            0.0,
1200            0,
1201            4,
1202            16,
1203            2,
1204        );
1205        assert_eq!(result, Err(BackendError::NotInitialized));
1206    }
1207
1208    #[test]
1209    fn backend_not_initialized_alloc() {
1210        let b = LevelZeroBackend::new();
1211        assert_eq!(b.alloc(1024), Err(BackendError::NotInitialized));
1212    }
1213
1214    #[test]
1215    fn backend_not_initialized_synchronize() {
1216        let b = LevelZeroBackend::new();
1217        assert_eq!(b.synchronize(), Err(BackendError::NotInitialized));
1218    }
1219
1220    #[test]
1221    fn backend_not_initialized_free() {
1222        let b = LevelZeroBackend::new();
1223        assert_eq!(b.free(1), Err(BackendError::NotInitialized));
1224    }
1225
1226    #[test]
1227    fn backend_not_initialized_copy_htod() {
1228        let b = LevelZeroBackend::new();
1229        assert_eq!(b.copy_htod(1, b"hello"), Err(BackendError::NotInitialized));
1230    }
1231
1232    #[test]
1233    fn backend_not_initialized_copy_dtoh() {
1234        let b = LevelZeroBackend::new();
1235        let mut buf = [0u8; 4];
1236        assert_eq!(b.copy_dtoh(&mut buf, 1), Err(BackendError::NotInitialized));
1237    }
1238
1239    // ── Helper: try to get an initialised backend (skip if no GPU or no loader) ─
1240
1241    fn try_init() -> Option<LevelZeroBackend> {
1242        let mut b = LevelZeroBackend::new();
1243        match b.init() {
1244            Ok(()) => Some(b),
1245            Err(_) => None,
1246        }
1247    }
1248
1249    // ── Graceful init failure ─────────────────────────────────────────────────
1250
1251    #[test]
1252    fn init_graceful_failure() {
1253        // Verify that init() returns a Result and never panics.
1254        let mut b = LevelZeroBackend::new();
1255        let _result = b.init();
1256        // Ok or Err — both are acceptable.
1257    }
1258
1259    // ── Zero-size / trivial-OK paths (post-init) ──────────────────────────────
1260
1261    #[test]
1262    fn alloc_zero_bytes_error() {
1263        let Some(b) = try_init() else {
1264            return;
1265        };
1266        assert_eq!(
1267            b.alloc(0),
1268            Err(BackendError::InvalidArgument(
1269                "cannot allocate 0 bytes".into()
1270            ))
1271        );
1272    }
1273
1274    #[test]
1275    fn copy_htod_empty_noop() {
1276        let Some(b) = try_init() else {
1277            return;
1278        };
1279        assert_eq!(b.copy_htod(0, &[]), Ok(()));
1280    }
1281
1282    #[test]
1283    fn copy_dtoh_empty_noop() {
1284        let Some(b) = try_init() else {
1285            return;
1286        };
1287        assert_eq!(b.copy_dtoh(&mut [], 0), Ok(()));
1288    }
1289
1290    #[test]
1291    fn gemm_zero_dims_noop() {
1292        let Some(b) = try_init() else {
1293            return;
1294        };
1295        assert_eq!(
1296            b.gemm(
1297                BackendTranspose::NoTrans,
1298                BackendTranspose::NoTrans,
1299                0,
1300                0,
1301                0,
1302                1.0,
1303                0,
1304                1,
1305                0,
1306                1,
1307                0.0,
1308                0,
1309                1
1310            ),
1311            Ok(())
1312        );
1313    }
1314
1315    #[test]
1316    fn batched_gemm_zero_batch_noop() {
1317        let Some(b) = try_init() else {
1318            return;
1319        };
1320        assert_eq!(
1321            b.batched_gemm(
1322                BackendTranspose::NoTrans,
1323                BackendTranspose::NoTrans,
1324                4,
1325                4,
1326                4,
1327                1.0,
1328                0,
1329                4,
1330                16,
1331                0,
1332                4,
1333                16,
1334                0.0,
1335                0,
1336                4,
1337                16,
1338                0,
1339            ),
1340            Ok(())
1341        );
1342    }
1343
1344    #[test]
1345    fn batched_gemm_zero_dims_noop() {
1346        let Some(b) = try_init() else {
1347            return;
1348        };
1349        assert_eq!(
1350            b.batched_gemm(
1351                BackendTranspose::NoTrans,
1352                BackendTranspose::NoTrans,
1353                0,
1354                0,
1355                0,
1356                1.0,
1357                0,
1358                1,
1359                0,
1360                0,
1361                1,
1362                0,
1363                0.0,
1364                0,
1365                1,
1366                0,
1367                3,
1368            ),
1369            Ok(())
1370        );
1371    }
1372
1373    #[test]
1374    fn unary_zero_n_noop() {
1375        let Some(b) = try_init() else {
1376            return;
1377        };
1378        assert_eq!(b.unary(UnaryOp::Relu, 0, 0, 0), Ok(()));
1379    }
1380
1381    #[test]
1382    fn binary_zero_n_noop() {
1383        let Some(b) = try_init() else {
1384            return;
1385        };
1386        assert_eq!(b.binary(BinaryOp::Add, 0, 0, 0, 0), Ok(()));
1387    }
1388
1389    #[test]
1390    fn synchronize_after_init() {
1391        let Some(b) = try_init() else {
1392            return;
1393        };
1394        assert_eq!(b.synchronize(), Ok(()));
1395    }
1396
1397    // ── Argument validation (post-init) ───────────────────────────────────────
1398
1399    #[test]
1400    fn reduce_empty_shape_error() {
1401        let Some(b) = try_init() else {
1402            return;
1403        };
1404        assert_eq!(
1405            b.reduce(ReduceOp::Sum, 0, 0, &[], 0),
1406            Err(BackendError::InvalidArgument(
1407                "shape must not be empty".into()
1408            ))
1409        );
1410    }
1411
1412    #[test]
1413    fn reduce_axis_out_of_bounds_error() {
1414        let Some(b) = try_init() else {
1415            return;
1416        };
1417        assert_eq!(
1418            b.reduce(ReduceOp::Sum, 0, 0, &[4, 4], 5),
1419            Err(BackendError::InvalidArgument(
1420                "axis 5 is out of bounds for shape of length 2".into()
1421            ))
1422        );
1423    }
1424
1425    #[test]
1426    fn attention_zero_seq_error() {
1427        let Some(b) = try_init() else {
1428            return;
1429        };
1430        assert_eq!(
1431            b.attention(0, 0, 0, 0, 1, 1, 0, 8, 64, 0.125, false),
1432            Err(BackendError::InvalidArgument(
1433                "seq_q, seq_kv, and head_dim must all be > 0".into()
1434            ))
1435        );
1436    }
1437
1438    #[test]
1439    fn attention_invalid_scale_error() {
1440        let Some(b) = try_init() else {
1441            return;
1442        };
1443        assert_eq!(
1444            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, 0.0, false),
1445            Err(BackendError::InvalidArgument(
1446                "scale must be a positive finite number, got 0".into()
1447            ))
1448        );
1449        assert_eq!(
1450            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, -1.0, false),
1451            Err(BackendError::InvalidArgument(
1452                "scale must be a positive finite number, got -1".into()
1453            ))
1454        );
1455        assert!(
1456            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, f64::INFINITY, false)
1457                .is_err()
1458        );
1459    }
1460
1461    #[test]
1462    fn conv2d_wrong_input_rank() {
1463        let Some(b) = try_init() else {
1464            return;
1465        };
1466        assert_eq!(
1467            b.conv2d_forward(
1468                0,
1469                &[1, 3, 32],
1470                0,
1471                &[16, 3, 3, 3],
1472                0,
1473                &[1, 16, 30, 30],
1474                &[1, 1],
1475                &[0, 0]
1476            ),
1477            Err(BackendError::InvalidArgument(
1478                "input_shape must have 4 elements (NCHW)".into()
1479            ))
1480        );
1481    }
1482
1483    #[test]
1484    fn conv2d_wrong_filter_rank() {
1485        let Some(b) = try_init() else {
1486            return;
1487        };
1488        assert_eq!(
1489            b.conv2d_forward(
1490                0,
1491                &[1, 3, 32, 32],
1492                0,
1493                &[16, 3, 3],
1494                0,
1495                &[1, 16, 30, 30],
1496                &[1, 1],
1497                &[0, 0]
1498            ),
1499            Err(BackendError::InvalidArgument(
1500                "filter_shape must have 4 elements (KCFHFW)".into()
1501            ))
1502        );
1503    }
1504
1505    // ── Init is idempotent ────────────────────────────────────────────────────
1506
1507    #[test]
1508    fn init_idempotent() {
1509        let Some(mut b) = try_init() else {
1510            return;
1511        };
1512        assert_eq!(b.init(), Ok(()));
1513        assert!(b.is_initialized());
1514    }
1515
1516    // ── alloc/free/copy roundtrip ─────────────────────────────────────────────
1517
1518    #[test]
1519    fn alloc_copy_roundtrip() {
1520        let Some(b) = try_init() else {
1521            return;
1522        };
1523        let src: Vec<u8> = (0u8..64).collect();
1524        let handle = match b.alloc(src.len()) {
1525            Ok(h) => h,
1526            Err(_) => return,
1527        };
1528        b.copy_htod(handle, &src).expect("copy_htod");
1529        let mut dst = vec![0u8; src.len()];
1530        b.copy_dtoh(&mut dst, handle).expect("copy_dtoh");
1531        assert_eq!(src, dst);
1532        b.free(handle).expect("free");
1533    }
1534
1535    // ── Double init is a no-op ────────────────────────────────────────────────
1536
1537    #[test]
1538    fn double_init_is_noop() {
1539        let Some(mut b) = try_init() else {
1540            return;
1541        };
1542        let first = b.is_initialized();
1543        let _ = b.init();
1544        assert_eq!(first, b.is_initialized());
1545    }
1546
1547    // ── alloc and free basic ──────────────────────────────────────────────────
1548
1549    #[test]
1550    fn alloc_and_free_basic() {
1551        let Some(b) = try_init() else {
1552            return;
1553        };
1554        match b.alloc(128) {
1555            Ok(handle) => {
1556                assert!(handle > 0);
1557                b.free(handle).expect("free should succeed");
1558            }
1559            Err(_) => {
1560                // Allocation failure is acceptable in environments without GPU.
1561            }
1562        }
1563    }
1564
1565    // ── Conv2D correctness tests ──────────────────────────────────────────────
1566
1567    /// Helper: allocate device memory, store f32 data, return handle.
1568    fn upload_f32(b: &LevelZeroBackend, data: &[f32]) -> Option<u64> {
1569        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_ne_bytes()).collect();
1570        let handle = b.alloc(bytes.len()).ok()?;
1571        b.copy_htod(handle, &bytes).ok()?;
1572        Some(handle)
1573    }
1574
1575    /// Helper: download f32 data from device.
1576    fn download_f32(b: &LevelZeroBackend, handle: u64, len: usize) -> Option<Vec<f32>> {
1577        let mut bytes = vec![0u8; len * 4];
1578        b.copy_dtoh(&mut bytes, handle).ok()?;
1579        Some(
1580            bytes
1581                .chunks_exact(4)
1582                .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
1583                .collect(),
1584        )
1585    }
1586
1587    #[test]
1588    fn l0_conv2d_identity_1x1() {
1589        let Some(b) = try_init() else {
1590            return;
1591        };
1592        // 1x1x4x4 input, 1x1x1x1 filter = identity (filter=[1.0])
1593        let input: Vec<f32> = (0..16).map(|i| i as f32).collect();
1594        let filter = vec![1.0f32];
1595        let output_len = 16;
1596
1597        let Some(in_h) = upload_f32(&b, &input) else {
1598            return;
1599        };
1600        let Some(flt_h) = upload_f32(&b, &filter) else {
1601            return;
1602        };
1603        let Some(out_h) = b.alloc(output_len * 4).ok() else {
1604            return;
1605        };
1606
1607        let result = b.conv2d_forward(
1608            in_h,
1609            &[1, 1, 4, 4],
1610            flt_h,
1611            &[1, 1, 1, 1],
1612            out_h,
1613            &[1, 1, 4, 4],
1614            &[1, 1],
1615            &[0, 0],
1616        );
1617        assert!(result.is_ok(), "conv2d_forward failed: {result:?}");
1618
1619        if let Some(out) = download_f32(&b, out_h, output_len) {
1620            for (i, &val) in out.iter().enumerate() {
1621                assert!(
1622                    (val - input[i]).abs() < 1e-5,
1623                    "mismatch at {i}: expected {}, got {val}",
1624                    input[i]
1625                );
1626            }
1627        }
1628
1629        let _ = b.free(in_h);
1630        let _ = b.free(flt_h);
1631        let _ = b.free(out_h);
1632    }
1633
1634    #[test]
1635    fn l0_conv2d_3x3_basic() {
1636        let Some(b) = try_init() else {
1637            return;
1638        };
1639        // 1x1x4x4 input, 1x1x3x3 filter, stride=1, pad=0 → 1x1x2x2 output
1640        let input: Vec<f32> = (0..16).map(|i| i as f32).collect();
1641        // All-ones 3x3 filter: output[oy,ox] = sum of 3x3 window
1642        let filter = vec![1.0f32; 9];
1643        let output_len = 4;
1644
1645        let Some(in_h) = upload_f32(&b, &input) else {
1646            return;
1647        };
1648        let Some(flt_h) = upload_f32(&b, &filter) else {
1649            return;
1650        };
1651        let Some(out_h) = b.alloc(output_len * 4).ok() else {
1652            return;
1653        };
1654
1655        let result = b.conv2d_forward(
1656            in_h,
1657            &[1, 1, 4, 4],
1658            flt_h,
1659            &[1, 1, 3, 3],
1660            out_h,
1661            &[1, 1, 2, 2],
1662            &[1, 1],
1663            &[0, 0],
1664        );
1665        assert!(result.is_ok());
1666
1667        // Expected:
1668        // out[0,0] = 0+1+2+4+5+6+8+9+10 = 45
1669        // out[0,1] = 1+2+3+5+6+7+9+10+11 = 54
1670        // out[1,0] = 4+5+6+8+9+10+12+13+14 = 81
1671        // out[1,1] = 5+6+7+9+10+11+13+14+15 = 90
1672        let expected = [45.0f32, 54.0, 81.0, 90.0];
1673        if let Some(out) = download_f32(&b, out_h, output_len) {
1674            for (i, &val) in out.iter().enumerate() {
1675                assert!(
1676                    (val - expected[i]).abs() < 1e-4,
1677                    "mismatch at {i}: expected {}, got {val}",
1678                    expected[i]
1679                );
1680            }
1681        }
1682
1683        let _ = b.free(in_h);
1684        let _ = b.free(flt_h);
1685        let _ = b.free(out_h);
1686    }
1687
1688    #[test]
1689    fn l0_conv2d_with_padding() {
1690        let Some(b) = try_init() else {
1691            return;
1692        };
1693        // 1x1x3x3 input, 1x1x3x3 filter (all ones), stride=1, pad=1 → 1x1x3x3 output
1694        let input: Vec<f32> = (1..=9).map(|i| i as f32).collect();
1695        let filter = vec![1.0f32; 9];
1696        let output_len = 9;
1697
1698        let Some(in_h) = upload_f32(&b, &input) else {
1699            return;
1700        };
1701        let Some(flt_h) = upload_f32(&b, &filter) else {
1702            return;
1703        };
1704        let Some(out_h) = b.alloc(output_len * 4).ok() else {
1705            return;
1706        };
1707
1708        let result = b.conv2d_forward(
1709            in_h,
1710            &[1, 1, 3, 3],
1711            flt_h,
1712            &[1, 1, 3, 3],
1713            out_h,
1714            &[1, 1, 3, 3],
1715            &[1, 1],
1716            &[1, 1],
1717        );
1718        assert!(result.is_ok());
1719
1720        // Center element: sum of all 9 = 45
1721        if let Some(out) = download_f32(&b, out_h, output_len) {
1722            assert!(
1723                (out[4] - 45.0).abs() < 1e-4,
1724                "center expected 45, got {}",
1725                out[4]
1726            );
1727            // Corner [0,0]: sum of [1,2,4,5] = 12
1728            assert!(
1729                (out[0] - 12.0).abs() < 1e-4,
1730                "corner expected 12, got {}",
1731                out[0]
1732            );
1733        }
1734
1735        let _ = b.free(in_h);
1736        let _ = b.free(flt_h);
1737        let _ = b.free(out_h);
1738    }
1739
1740    // ── Attention correctness tests ───────────────────────────────────────────
1741
1742    #[test]
1743    fn l0_attention_uniform() {
1744        let Some(b) = try_init() else {
1745            return;
1746        };
1747        // batch=1, heads=1, seq_q=2, seq_kv=2, head_dim=2
1748        // Q=K=all zeros → uniform attention → O = mean(V)
1749        let seq_q = 2;
1750        let seq_kv = 2;
1751        let head_dim = 2;
1752        let q = vec![0.0f32; seq_q * head_dim];
1753        let k = vec![0.0f32; seq_kv * head_dim];
1754        let v = vec![1.0f32, 2.0, 3.0, 4.0]; // V[0]=[1,2], V[1]=[3,4]
1755        let o_len = seq_q * head_dim;
1756
1757        let Some(q_h) = upload_f32(&b, &q) else {
1758            return;
1759        };
1760        let Some(k_h) = upload_f32(&b, &k) else {
1761            return;
1762        };
1763        let Some(v_h) = upload_f32(&b, &v) else {
1764            return;
1765        };
1766        let Some(o_h) = b.alloc(o_len * 4).ok() else {
1767            return;
1768        };
1769        // Zero out output
1770        let zeros = vec![0u8; o_len * 4];
1771        let _ = b.copy_htod(o_h, &zeros);
1772
1773        let scale = 1.0 / (head_dim as f64).sqrt();
1774        let result = b.attention(
1775            q_h, k_h, v_h, o_h, 1, 1, seq_q, seq_kv, head_dim, scale, false,
1776        );
1777        assert!(result.is_ok(), "attention failed: {result:?}");
1778
1779        // With uniform attention weights, output = mean(V rows)
1780        // mean = [(1+3)/2, (2+4)/2] = [2, 3]
1781        if let Some(out) = download_f32(&b, o_h, o_len) {
1782            // Both query positions should get the same result
1783            for sq_idx in 0..seq_q {
1784                let off = sq_idx * head_dim;
1785                assert!(
1786                    (out[off] - 2.0).abs() < 1e-4,
1787                    "q{sq_idx}[0] expected 2.0, got {}",
1788                    out[off]
1789                );
1790                assert!(
1791                    (out[off + 1] - 3.0).abs() < 1e-4,
1792                    "q{sq_idx}[1] expected 3.0, got {}",
1793                    out[off + 1]
1794                );
1795            }
1796        }
1797
1798        let _ = b.free(q_h);
1799        let _ = b.free(k_h);
1800        let _ = b.free(v_h);
1801        let _ = b.free(o_h);
1802    }
1803
1804    #[test]
1805    fn l0_attention_causal() {
1806        let Some(b) = try_init() else {
1807            return;
1808        };
1809        // batch=1, heads=1, seq_q=2, seq_kv=2, head_dim=2, causal=true
1810        let seq_q = 2;
1811        let seq_kv = 2;
1812        let head_dim = 2;
1813        let q = vec![0.0f32; seq_q * head_dim];
1814        let k = vec![0.0f32; seq_kv * head_dim];
1815        let v = vec![1.0f32, 2.0, 3.0, 4.0];
1816        let o_len = seq_q * head_dim;
1817
1818        let Some(q_h) = upload_f32(&b, &q) else {
1819            return;
1820        };
1821        let Some(k_h) = upload_f32(&b, &k) else {
1822            return;
1823        };
1824        let Some(v_h) = upload_f32(&b, &v) else {
1825            return;
1826        };
1827        let Some(o_h) = b.alloc(o_len * 4).ok() else {
1828            return;
1829        };
1830        let zeros = vec![0u8; o_len * 4];
1831        let _ = b.copy_htod(o_h, &zeros);
1832
1833        let scale = 1.0 / (head_dim as f64).sqrt();
1834        let result = b.attention(
1835            q_h, k_h, v_h, o_h, 1, 1, seq_q, seq_kv, head_dim, scale, true,
1836        );
1837        assert!(result.is_ok());
1838
1839        if let Some(out) = download_f32(&b, o_h, o_len) {
1840            // q=0 (causal: can only attend to k=0): output = V[0] = [1, 2]
1841            assert!(
1842                (out[0] - 1.0).abs() < 1e-4,
1843                "q0[0] expected 1.0, got {}",
1844                out[0]
1845            );
1846            assert!(
1847                (out[1] - 2.0).abs() < 1e-4,
1848                "q0[1] expected 2.0, got {}",
1849                out[1]
1850            );
1851            // q=1 (can attend to k=0,1): output = mean(V) = [2, 3]
1852            assert!(
1853                (out[2] - 2.0).abs() < 1e-4,
1854                "q1[0] expected 2.0, got {}",
1855                out[2]
1856            );
1857            assert!(
1858                (out[3] - 3.0).abs() < 1e-4,
1859                "q1[1] expected 3.0, got {}",
1860                out[3]
1861            );
1862        }
1863
1864        let _ = b.free(q_h);
1865        let _ = b.free(k_h);
1866        let _ = b.free(v_h);
1867        let _ = b.free(o_h);
1868    }
1869
1870    #[test]
1871    fn l0_attention_dominant_key() {
1872        let Some(b) = try_init() else {
1873            return;
1874        };
1875        // One key has a very large dot product → attention should be concentrated on it
1876        let seq_q = 1;
1877        let seq_kv = 3;
1878        let head_dim = 2;
1879        // Q = [10, 0]
1880        let q = vec![10.0f32, 0.0];
1881        // K = [[10, 0], [0, 0], [0, 0]]  → dot(Q,K[0]) = 100, others = 0
1882        let k = vec![10.0f32, 0.0, 0.0, 0.0, 0.0, 0.0];
1883        let v = vec![1.0f32, 0.0, 0.0, 1.0, 0.0, 0.0]; // V[0]=[1,0], V[1]=[0,1], V[2]=[0,0]
1884        let o_len = seq_q * head_dim;
1885
1886        let Some(q_h) = upload_f32(&b, &q) else {
1887            return;
1888        };
1889        let Some(k_h) = upload_f32(&b, &k) else {
1890            return;
1891        };
1892        let Some(v_h) = upload_f32(&b, &v) else {
1893            return;
1894        };
1895        let Some(o_h) = b.alloc(o_len * 4).ok() else {
1896            return;
1897        };
1898        let zeros = vec![0u8; o_len * 4];
1899        let _ = b.copy_htod(o_h, &zeros);
1900
1901        let scale = 1.0;
1902        let result = b.attention(
1903            q_h, k_h, v_h, o_h, 1, 1, seq_q, seq_kv, head_dim, scale, false,
1904        );
1905        assert!(result.is_ok());
1906
1907        if let Some(out) = download_f32(&b, o_h, o_len) {
1908            // With dot=100*scale=100 for K[0] vs 0 for others,
1909            // softmax should heavily favour V[0]=[1,0]
1910            assert!(out[0] > 0.99, "expected output[0] ≈ 1.0, got {}", out[0]);
1911            assert!(out[1] < 0.01, "expected output[1] ≈ 0.0, got {}", out[1]);
1912        }
1913
1914        let _ = b.free(q_h);
1915        let _ = b.free(k_h);
1916        let _ = b.free(v_h);
1917        let _ = b.free(o_h);
1918    }
1919}