// oxicuda_levelzero/backend.rs

//! [`LevelZeroBackend`] — the main entry point for the oxicuda-levelzero crate.
//!
//! Implements the [`ComputeBackend`] trait from `oxicuda-backend` using
//! Intel's Level Zero API for GPU compute on Linux and Windows.

6use std::sync::Arc;
7
8use oxicuda_backend::{
9    BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
10};
11
12use crate::{device::LevelZeroDevice, memory::LevelZeroMemoryManager};
13
14// ─── Backend struct ───────────────────────────────────────────────────────────
15
/// Intel Level Zero GPU compute backend.
///
/// On Linux and Windows this selects the first Intel GPU via the Level Zero
/// loader library (`libze_loader.so` / `ze_loader.dll`) and allocates device
/// memory through the Level Zero memory model.
///
/// On macOS every operation returns [`BackendError::DeviceError`] wrapping
/// [`crate::error::LevelZeroError::UnsupportedPlatform`].
///
/// # Lifecycle
///
/// 1. `LevelZeroBackend::new()` — create an uninitialised backend.
/// 2. `init()` — load the Level Zero driver and select a GPU.
/// 3. Use `alloc`, `copy_htod`, compute ops, `copy_dtoh`, `free`.
/// 4. `synchronize()` — wait for all pending GPU work to finish.
#[derive(Debug)]
pub struct LevelZeroBackend {
    // Selected GPU device; `None` until `init()` succeeds.
    device: Option<Arc<LevelZeroDevice>>,
    // Device-memory allocator sharing ownership of `device`; `None` until
    // `init()` succeeds.
    memory: Option<Arc<LevelZeroMemoryManager>>,
    // Set to true by a successful `init()`; `init()` and `new()` always keep
    // this in step with `device`/`memory` being `Some`/`None`.
    initialized: bool,
}
37
38impl LevelZeroBackend {
39    /// Create a new, uninitialised Level Zero backend.
40    pub fn new() -> Self {
41        Self {
42            device: None,
43            memory: None,
44            initialized: false,
45        }
46    }
47
48    /// Return an error if the backend has not been initialised yet.
49    fn check_init(&self) -> BackendResult<()> {
50        if self.initialized {
51            Ok(())
52        } else {
53            Err(BackendError::NotInitialized)
54        }
55    }
56
57    /// Convenience accessor: get the memory manager or return `NotInitialized`.
58    fn memory(&self) -> BackendResult<&Arc<LevelZeroMemoryManager>> {
59        self.memory.as_ref().ok_or(BackendError::NotInitialized)
60    }
61}
62
63impl Default for LevelZeroBackend {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69// ─── ComputeBackend impl ──────────────────────────────────────────────────────
70
71impl ComputeBackend for LevelZeroBackend {
72    fn name(&self) -> &str {
73        "level-zero"
74    }
75
76    fn init(&mut self) -> BackendResult<()> {
77        if self.initialized {
78            return Ok(());
79        }
80        match LevelZeroDevice::new() {
81            Ok(dev) => {
82                let dev = Arc::new(dev);
83                tracing::info!("Level Zero backend initialised on: {}", dev.name());
84                let memory = LevelZeroMemoryManager::new(Arc::clone(&dev));
85                self.device = Some(dev);
86                self.memory = Some(Arc::new(memory));
87                self.initialized = true;
88                Ok(())
89            }
90            Err(e) => Err(BackendError::from(e)),
91        }
92    }
93
94    fn is_initialized(&self) -> bool {
95        self.initialized
96    }
97
98    // ── Compute operations ────────────────────────────────────────────────────
99
100    fn gemm(
101        &self,
102        _trans_a: BackendTranspose,
103        _trans_b: BackendTranspose,
104        m: usize,
105        n: usize,
106        k: usize,
107        _alpha: f64,
108        _a_ptr: u64,
109        _lda: usize,
110        _b_ptr: u64,
111        _ldb: usize,
112        _beta: f64,
113        _c_ptr: u64,
114        _ldc: usize,
115    ) -> BackendResult<()> {
116        self.check_init()?;
117        // Zero-dimension matrices are trivially complete.
118        if m == 0 || n == 0 || k == 0 {
119            return Ok(());
120        }
121        Err(BackendError::Unsupported(
122            "level-zero: gemm not yet wired".into(),
123        ))
124    }
125
126    fn conv2d_forward(
127        &self,
128        _input_ptr: u64,
129        input_shape: &[usize],
130        _filter_ptr: u64,
131        filter_shape: &[usize],
132        _output_ptr: u64,
133        output_shape: &[usize],
134        stride: &[usize],
135        padding: &[usize],
136    ) -> BackendResult<()> {
137        self.check_init()?;
138
139        if input_shape.len() != 4 {
140            return Err(BackendError::InvalidArgument(
141                "input_shape must have 4 elements (NCHW)".into(),
142            ));
143        }
144        if filter_shape.len() != 4 {
145            return Err(BackendError::InvalidArgument(
146                "filter_shape must have 4 elements (KCFHFW)".into(),
147            ));
148        }
149        if output_shape.len() != 4 {
150            return Err(BackendError::InvalidArgument(
151                "output_shape must have 4 elements (NKOhOw)".into(),
152            ));
153        }
154        if stride.len() != 2 {
155            return Err(BackendError::InvalidArgument(
156                "stride must have 2 elements [sh, sw]".into(),
157            ));
158        }
159        if padding.len() != 2 {
160            return Err(BackendError::InvalidArgument(
161                "padding must have 2 elements [ph, pw]".into(),
162            ));
163        }
164
165        Err(BackendError::Unsupported(
166            "level-zero: conv2d not yet wired".into(),
167        ))
168    }
169
170    fn attention(
171        &self,
172        _q_ptr: u64,
173        _k_ptr: u64,
174        _v_ptr: u64,
175        _o_ptr: u64,
176        _batch: usize,
177        _heads: usize,
178        seq_q: usize,
179        seq_kv: usize,
180        head_dim: usize,
181        scale: f64,
182        _causal: bool,
183    ) -> BackendResult<()> {
184        self.check_init()?;
185
186        if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
187            return Err(BackendError::InvalidArgument(
188                "seq_q, seq_kv, and head_dim must all be > 0".into(),
189            ));
190        }
191        if scale <= 0.0 || !scale.is_finite() {
192            return Err(BackendError::InvalidArgument(format!(
193                "scale must be a positive finite number, got {scale}"
194            )));
195        }
196
197        Err(BackendError::Unsupported(
198            "level-zero: attention not yet wired".into(),
199        ))
200    }
201
202    fn reduce(
203        &self,
204        _op: ReduceOp,
205        _input_ptr: u64,
206        _output_ptr: u64,
207        shape: &[usize],
208        axis: usize,
209    ) -> BackendResult<()> {
210        self.check_init()?;
211
212        if shape.is_empty() {
213            return Err(BackendError::InvalidArgument(
214                "shape must not be empty".into(),
215            ));
216        }
217        if axis >= shape.len() {
218            return Err(BackendError::InvalidArgument(format!(
219                "axis {axis} is out of bounds for shape of length {}",
220                shape.len()
221            )));
222        }
223
224        Err(BackendError::Unsupported(
225            "level-zero: reduce not yet wired".into(),
226        ))
227    }
228
229    fn unary(
230        &self,
231        _op: UnaryOp,
232        _input_ptr: u64,
233        _output_ptr: u64,
234        n: usize,
235    ) -> BackendResult<()> {
236        self.check_init()?;
237        if n == 0 {
238            return Ok(());
239        }
240        Err(BackendError::Unsupported(
241            "level-zero: unary not yet wired".into(),
242        ))
243    }
244
245    fn binary(
246        &self,
247        _op: BinaryOp,
248        _a_ptr: u64,
249        _b_ptr: u64,
250        _output_ptr: u64,
251        n: usize,
252    ) -> BackendResult<()> {
253        self.check_init()?;
254        if n == 0 {
255            return Ok(());
256        }
257        Err(BackendError::Unsupported(
258            "level-zero: binary not yet wired".into(),
259        ))
260    }
261
262    // ── Synchronisation ───────────────────────────────────────────────────────
263
264    fn synchronize(&self) -> BackendResult<()> {
265        self.check_init()?;
266
267        #[cfg(any(target_os = "linux", target_os = "windows"))]
268        {
269            if let Some(dev) = &self.device {
270                let api = &dev.api;
271                let queue = dev.queue;
272                // SAFETY: `queue` is a valid command queue handle and the
273                // backend is initialized.  u64::MAX means "wait indefinitely".
274                let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
275                if rc != 0 {
276                    return Err(BackendError::DeviceError(format!(
277                        "zeCommandQueueSynchronize failed: 0x{rc:08x}"
278                    )));
279                }
280            }
281        }
282
283        Ok(())
284    }
285
286    // ── Memory management ─────────────────────────────────────────────────────
287
288    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
289        self.check_init()?;
290        if bytes == 0 {
291            return Err(BackendError::InvalidArgument(
292                "cannot allocate 0 bytes".into(),
293            ));
294        }
295        self.memory()?.alloc(bytes).map_err(BackendError::from)
296    }
297
298    fn free(&self, ptr: u64) -> BackendResult<()> {
299        self.check_init()?;
300        self.memory()?.free(ptr).map_err(BackendError::from)
301    }
302
303    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
304        self.check_init()?;
305        if src.is_empty() {
306            return Ok(());
307        }
308        self.memory()?
309            .copy_to_device(dst, src)
310            .map_err(BackendError::from)
311    }
312
313    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
314        self.check_init()?;
315        if dst.is_empty() {
316            return Ok(());
317        }
318        self.memory()?
319            .copy_from_device(dst, src)
320            .map_err(BackendError::from)
321    }
322}
323
324// ─── Tests ───────────────────────────────────────────────────────────────────
325
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): these names already reach this module via `use super::*`
    // (the parent imports them); the explicit re-import is redundant but
    // harmless — explicit imports shadow glob imports.
    use oxicuda_backend::{BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp};

    // ── Construction ──────────────────────────────────────────────────────────

    #[test]
    fn level_zero_backend_new_uninitialized() {
        let b = LevelZeroBackend::new();
        assert!(!b.is_initialized());
    }

    #[test]
    fn level_zero_backend_name() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.name(), "level-zero");
    }

    #[test]
    fn level_zero_backend_default() {
        let b = LevelZeroBackend::default();
        assert!(!b.is_initialized());
        assert_eq!(b.name(), "level-zero");
    }

    #[test]
    fn backend_debug_impl() {
        let b = LevelZeroBackend::new();
        let s = format!("{b:?}");
        assert!(s.contains("LevelZeroBackend"));
    }

    // ── Object-safety smoke test ──────────────────────────────────────────────

    // Compiles only if `ComputeBackend` is object-safe (usable as `dyn`).
    #[test]
    fn backend_object_safe() {
        let b: Box<dyn ComputeBackend> = Box::new(LevelZeroBackend::new());
        assert_eq!(b.name(), "level-zero");
    }

    // ── Not-initialized guards ────────────────────────────────────────────────
    // Every operation must fail with `NotInitialized` before `init()`.

    #[test]
    fn backend_not_initialized_gemm() {
        let b = LevelZeroBackend::new();
        let result = b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            4,
            4,
            1.0,
            0,
            4,
            0,
            4,
            0.0,
            0,
            4,
        );
        assert_eq!(result, Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_alloc() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.alloc(1024), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_synchronize() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.synchronize(), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_free() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.free(1), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_htod() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.copy_htod(1, b"hello"), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_dtoh() {
        let b = LevelZeroBackend::new();
        let mut buf = [0u8; 4];
        assert_eq!(b.copy_dtoh(&mut buf, 1), Err(BackendError::NotInitialized));
    }

    // ── Helper: try to get an initialised backend (skip if no GPU or no loader) ─

    /// Initialise a backend, or return `None` when no Level Zero loader / GPU
    /// is available.  Tests below early-return on `None`, so they pass (as
    /// skipped) on machines without hardware.
    fn try_init() -> Option<LevelZeroBackend> {
        let mut b = LevelZeroBackend::new();
        match b.init() {
            Ok(()) => Some(b),
            Err(_) => None,
        }
    }

    // ── Graceful init failure ─────────────────────────────────────────────────

    #[test]
    fn init_graceful_failure() {
        // Verify that init() returns a Result and never panics.
        let mut b = LevelZeroBackend::new();
        let _result = b.init();
        // Ok or Err — both are acceptable.
    }

    // ── Zero-size / trivial-OK paths (post-init) ──────────────────────────────

    #[test]
    fn alloc_zero_bytes_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.alloc(0),
            Err(BackendError::InvalidArgument(
                "cannot allocate 0 bytes".into()
            ))
        );
    }

    #[test]
    fn copy_htod_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.copy_htod(0, &[]), Ok(()));
    }

    #[test]
    fn copy_dtoh_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.copy_dtoh(&mut [], 0), Ok(()));
    }

    // m == n == k == 0: C is empty, so gemm must succeed without a kernel.
    #[test]
    fn gemm_zero_dims_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.gemm(
                BackendTranspose::NoTrans,
                BackendTranspose::NoTrans,
                0,
                0,
                0,
                1.0,
                0,
                1,
                0,
                1,
                0.0,
                0,
                1
            ),
            Ok(())
        );
    }

    #[test]
    fn unary_zero_n_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.unary(UnaryOp::Relu, 0, 0, 0), Ok(()));
    }

    #[test]
    fn binary_zero_n_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.binary(BinaryOp::Add, 0, 0, 0, 0), Ok(()));
    }

    #[test]
    fn synchronize_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.synchronize(), Ok(()));
    }

    // ── Argument validation (post-init) ───────────────────────────────────────

    #[test]
    fn reduce_empty_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[], 0),
            Err(BackendError::InvalidArgument(
                "shape must not be empty".into()
            ))
        );
    }

    #[test]
    fn reduce_axis_out_of_bounds_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[4, 4], 5),
            Err(BackendError::InvalidArgument(
                "axis 5 is out of bounds for shape of length 2".into()
            ))
        );
    }

    #[test]
    fn attention_zero_seq_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 0, 8, 64, 0.125, false),
            Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into()
            ))
        );
    }

    // `{scale}` formats via `Display`, so 0.0 renders as "0" and -1.0 as "-1".
    #[test]
    fn attention_invalid_scale_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, 0.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got 0".into()
            ))
        );
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, -1.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got -1".into()
            ))
        );
        assert!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, f64::INFINITY, false)
                .is_err()
        );
    }

    #[test]
    fn conv2d_wrong_input_rank() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32],
                0,
                &[16, 3, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "input_shape must have 4 elements (NCHW)".into()
            ))
        );
    }

    #[test]
    fn conv2d_wrong_filter_rank() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32, 32],
                0,
                &[16, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "filter_shape must have 4 elements (KCFHFW)".into()
            ))
        );
    }

    // ── Init is idempotent ────────────────────────────────────────────────────

    #[test]
    fn init_idempotent() {
        let Some(mut b) = try_init() else {
            return;
        };
        assert_eq!(b.init(), Ok(()));
        assert!(b.is_initialized());
    }

    // ── alloc/free/copy roundtrip ─────────────────────────────────────────────

    // Writes a byte pattern to the device and reads it back unchanged.
    #[test]
    fn alloc_copy_roundtrip() {
        let Some(b) = try_init() else {
            return;
        };
        let src: Vec<u8> = (0u8..64).collect();
        let handle = match b.alloc(src.len()) {
            Ok(h) => h,
            Err(_) => return,
        };
        b.copy_htod(handle, &src).expect("copy_htod");
        let mut dst = vec![0u8; src.len()];
        b.copy_dtoh(&mut dst, handle).expect("copy_dtoh");
        assert_eq!(src, dst);
        b.free(handle).expect("free");
    }

    // ── Double init is a no-op ────────────────────────────────────────────────

    #[test]
    fn double_init_is_noop() {
        let Some(mut b) = try_init() else {
            return;
        };
        let first = b.is_initialized();
        let _ = b.init();
        assert_eq!(first, b.is_initialized());
    }

    // ── alloc and free basic ──────────────────────────────────────────────────

    #[test]
    fn alloc_and_free_basic() {
        let Some(b) = try_init() else {
            return;
        };
        match b.alloc(128) {
            Ok(handle) => {
                assert!(handle > 0);
                b.free(handle).expect("free should succeed");
            }
            Err(_) => {
                // Allocation failure is acceptable in environments without GPU.
            }
        }
    }
}
688}