1use std::sync::Arc;
7
8use oxicuda_backend::{
9 BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
10};
11
12use crate::{device::LevelZeroDevice, memory::LevelZeroMemoryManager};
13
/// Compute backend backed by Intel's Level Zero driver.
///
/// All fields start empty/false; `init` (on the `ComputeBackend` impl)
/// populates `device` and `memory` and flips `initialized`. Every other
/// operation is gated on `initialized` via `check_init`.
#[derive(Debug)]
pub struct LevelZeroBackend {
    // Selected Level Zero device; `None` until `init` succeeds.
    device: Option<Arc<LevelZeroDevice>>,
    // Allocator bound to `device`; `None` until `init` succeeds.
    memory: Option<Arc<LevelZeroMemoryManager>>,
    // True once `init` has completed successfully; never reset in this file.
    initialized: bool,
}
37
38impl LevelZeroBackend {
39 pub fn new() -> Self {
41 Self {
42 device: None,
43 memory: None,
44 initialized: false,
45 }
46 }
47
48 fn check_init(&self) -> BackendResult<()> {
50 if self.initialized {
51 Ok(())
52 } else {
53 Err(BackendError::NotInitialized)
54 }
55 }
56
57 fn memory(&self) -> BackendResult<&Arc<LevelZeroMemoryManager>> {
59 self.memory.as_ref().ok_or(BackendError::NotInitialized)
60 }
61}
62
63impl Default for LevelZeroBackend {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
impl ComputeBackend for LevelZeroBackend {
    /// Stable identifier for backend selection/registration.
    fn name(&self) -> &str {
        "level-zero"
    }

    /// Binds a Level Zero device and its memory manager.
    ///
    /// Idempotent: calling `init` on an already-initialised backend is a
    /// no-op returning `Ok(())`. On device-creation failure the backend
    /// stays uninitialised and the driver error is wrapped via
    /// `BackendError::from`.
    fn init(&mut self) -> BackendResult<()> {
        if self.initialized {
            return Ok(());
        }
        match LevelZeroDevice::new() {
            Ok(dev) => {
                let dev = Arc::new(dev);
                tracing::info!("Level Zero backend initialised on: {}", dev.name());
                // The memory manager shares ownership of the device handle.
                let memory = LevelZeroMemoryManager::new(Arc::clone(&dev));
                self.device = Some(dev);
                self.memory = Some(Arc::new(memory));
                self.initialized = true;
                Ok(())
            }
            Err(e) => Err(BackendError::from(e)),
        }
    }

    fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// General matrix multiply C = alpha*op(A)*op(B) + beta*C.
    ///
    /// Currently a stub: validates initialisation, treats any zero
    /// dimension as a no-op, and otherwise reports `Unsupported` until
    /// the device kernel is wired up.
    fn gemm(
        &self,
        _trans_a: BackendTranspose,
        _trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        _alpha: f64,
        _a_ptr: u64,
        _lda: usize,
        _b_ptr: u64,
        _ldb: usize,
        _beta: f64,
        _c_ptr: u64,
        _ldc: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        // An empty product has nothing to compute; succeed without touching
        // device memory (pointers/ld* are intentionally not validated here).
        if m == 0 || n == 0 || k == 0 {
            return Ok(());
        }
        Err(BackendError::Unsupported(
            "level-zero: gemm not yet wired".into(),
        ))
    }

    /// 2-D convolution forward pass (NCHW layout, per the error messages).
    ///
    /// Currently a stub: performs full rank/length validation of the shape
    /// arguments, then reports `Unsupported`. Shape *contents* (e.g. that
    /// output dims match input/filter/stride/padding) are not yet checked.
    fn conv2d_forward(
        &self,
        _input_ptr: u64,
        input_shape: &[usize],
        _filter_ptr: u64,
        filter_shape: &[usize],
        _output_ptr: u64,
        output_shape: &[usize],
        stride: &[usize],
        padding: &[usize],
    ) -> BackendResult<()> {
        self.check_init()?;

        if input_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "input_shape must have 4 elements (NCHW)".into(),
            ));
        }
        if filter_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "filter_shape must have 4 elements (KCFHFW)".into(),
            ));
        }
        if output_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "output_shape must have 4 elements (NKOhOw)".into(),
            ));
        }
        if stride.len() != 2 {
            return Err(BackendError::InvalidArgument(
                "stride must have 2 elements [sh, sw]".into(),
            ));
        }
        if padding.len() != 2 {
            return Err(BackendError::InvalidArgument(
                "padding must have 2 elements [ph, pw]".into(),
            ));
        }

        Err(BackendError::Unsupported(
            "level-zero: conv2d not yet wired".into(),
        ))
    }

    /// Scaled dot-product attention.
    ///
    /// Currently a stub: rejects empty sequences/head dims and non-positive
    /// or non-finite scales, then reports `Unsupported`. Note that zero
    /// dims are an *error* here (unlike `gemm`, where they are a no-op).
    fn attention(
        &self,
        _q_ptr: u64,
        _k_ptr: u64,
        _v_ptr: u64,
        _o_ptr: u64,
        _batch: usize,
        _heads: usize,
        seq_q: usize,
        seq_kv: usize,
        head_dim: usize,
        scale: f64,
        _causal: bool,
    ) -> BackendResult<()> {
        self.check_init()?;

        if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
            return Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into(),
            ));
        }
        // Rejects 0, negatives, NaN, and +/-infinity in one test; the
        // message interpolates `scale` with default Display formatting.
        if scale <= 0.0 || !scale.is_finite() {
            return Err(BackendError::InvalidArgument(format!(
                "scale must be a positive finite number, got {scale}"
            )));
        }

        Err(BackendError::Unsupported(
            "level-zero: attention not yet wired".into(),
        ))
    }

    /// Reduction along a single axis of a tensor.
    ///
    /// Currently a stub: validates that `shape` is non-empty and `axis` is
    /// in bounds, then reports `Unsupported`.
    fn reduce(
        &self,
        _op: ReduceOp,
        _input_ptr: u64,
        _output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()> {
        self.check_init()?;

        if shape.is_empty() {
            return Err(BackendError::InvalidArgument(
                "shape must not be empty".into(),
            ));
        }
        if axis >= shape.len() {
            return Err(BackendError::InvalidArgument(format!(
                "axis {axis} is out of bounds for shape of length {}",
                shape.len()
            )));
        }

        Err(BackendError::Unsupported(
            "level-zero: reduce not yet wired".into(),
        ))
    }

    /// Elementwise unary op over `n` elements. Stub: `n == 0` is a no-op,
    /// anything else reports `Unsupported`.
    fn unary(
        &self,
        _op: UnaryOp,
        _input_ptr: u64,
        _output_ptr: u64,
        n: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }
        Err(BackendError::Unsupported(
            "level-zero: unary not yet wired".into(),
        ))
    }

    /// Elementwise binary op over `n` elements. Stub: `n == 0` is a no-op,
    /// anything else reports `Unsupported`.
    fn binary(
        &self,
        _op: BinaryOp,
        _a_ptr: u64,
        _b_ptr: u64,
        _output_ptr: u64,
        n: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }
        Err(BackendError::Unsupported(
            "level-zero: binary not yet wired".into(),
        ))
    }

    /// Blocks until the device command queue has drained.
    ///
    /// On non-Linux/Windows targets this compiles to a no-op that only
    /// checks initialisation. `u64::MAX` is passed as the timeout —
    /// presumably "wait indefinitely"; confirm against the
    /// zeCommandQueueSynchronize specification.
    fn synchronize(&self) -> BackendResult<()> {
        self.check_init()?;

        #[cfg(any(target_os = "linux", target_os = "windows"))]
        {
            if let Some(dev) = &self.device {
                let api = &dev.api;
                let queue = dev.queue;
                // SAFETY: assumes `dev.queue` is a live command queue handle
                // owned by `dev` and that the loaded function pointer matches
                // the driver's zeCommandQueueSynchronize signature — both are
                // invariants upheld by LevelZeroDevice::new (TODO confirm).
                let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
                // Level Zero returns 0 (ZE_RESULT_SUCCESS) on success; any
                // other code is surfaced as a device error in hex.
                if rc != 0 {
                    return Err(BackendError::DeviceError(format!(
                        "zeCommandQueueSynchronize failed: 0x{rc:08x}"
                    )));
                }
            }
        }

        Ok(())
    }

    /// Allocates `bytes` of device memory, returning an opaque handle.
    /// Zero-byte allocations are rejected up front.
    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
        self.check_init()?;
        if bytes == 0 {
            return Err(BackendError::InvalidArgument(
                "cannot allocate 0 bytes".into(),
            ));
        }
        self.memory()?.alloc(bytes).map_err(BackendError::from)
    }

    /// Frees a handle previously returned by `alloc`. Validation of the
    /// handle itself is delegated to the memory manager.
    fn free(&self, ptr: u64) -> BackendResult<()> {
        self.check_init()?;
        self.memory()?.free(ptr).map_err(BackendError::from)
    }

    /// Copies host bytes to device memory at `dst`. An empty slice is a
    /// no-op (so `dst` is never dereferenced in that case).
    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
        self.check_init()?;
        if src.is_empty() {
            return Ok(());
        }
        self.memory()?
            .copy_to_device(dst, src)
            .map_err(BackendError::from)
    }

    /// Copies device memory at `src` into the host buffer `dst`. An empty
    /// buffer is a no-op (so `src` is never dereferenced in that case).
    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
        self.check_init()?;
        if dst.is_empty() {
            return Ok(());
        }
        self.memory()?
            .copy_from_device(dst, src)
            .map_err(BackendError::from)
    }
}
323
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_backend::{BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp};

    // ---- Construction & metadata: run everywhere, no device needed. ----

    #[test]
    fn level_zero_backend_new_uninitialized() {
        let b = LevelZeroBackend::new();
        assert!(!b.is_initialized());
    }

    #[test]
    fn level_zero_backend_name() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.name(), "level-zero");
    }

    #[test]
    fn level_zero_backend_default() {
        // Default must match new(): uninitialised, same name.
        let b = LevelZeroBackend::default();
        assert!(!b.is_initialized());
        assert_eq!(b.name(), "level-zero");
    }

    #[test]
    fn backend_debug_impl() {
        let b = LevelZeroBackend::new();
        let s = format!("{b:?}");
        assert!(s.contains("LevelZeroBackend"));
    }

    // The backend must remain usable behind a trait object.
    #[test]
    fn backend_object_safe() {
        let b: Box<dyn ComputeBackend> = Box::new(LevelZeroBackend::new());
        assert_eq!(b.name(), "level-zero");
    }

    // ---- Every operation on an uninitialised backend → NotInitialized. ----

    #[test]
    fn backend_not_initialized_gemm() {
        let b = LevelZeroBackend::new();
        let result = b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            4,
            4,
            1.0,
            0,
            4,
            0,
            4,
            0.0,
            0,
            4,
        );
        assert_eq!(result, Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_alloc() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.alloc(1024), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_synchronize() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.synchronize(), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_free() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.free(1), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_htod() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.copy_htod(1, b"hello"), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_dtoh() {
        let b = LevelZeroBackend::new();
        let mut buf = [0u8; 4];
        assert_eq!(b.copy_dtoh(&mut buf, 1), Err(BackendError::NotInitialized));
    }

    // Returns an initialised backend, or None when no Level Zero driver or
    // device is available — the device-dependent tests below become silent
    // no-ops on machines without the hardware.
    fn try_init() -> Option<LevelZeroBackend> {
        let mut b = LevelZeroBackend::new();
        match b.init() {
            Ok(()) => Some(b),
            Err(_) => None,
        }
    }

    // init() must never panic, even when no driver is present.
    #[test]
    fn init_graceful_failure() {
        let mut b = LevelZeroBackend::new();
        let _result = b.init();
    }

    // ---- Validation behavior on an initialised backend. ----

    #[test]
    fn alloc_zero_bytes_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.alloc(0),
            Err(BackendError::InvalidArgument(
                "cannot allocate 0 bytes".into()
            ))
        );
    }

    #[test]
    fn copy_htod_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        // Empty source: must succeed without dereferencing dst (0 here).
        assert_eq!(b.copy_htod(0, &[]), Ok(()));
    }

    #[test]
    fn copy_dtoh_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        // Empty destination: must succeed without dereferencing src.
        assert_eq!(b.copy_dtoh(&mut [], 0), Ok(()));
    }

    #[test]
    fn gemm_zero_dims_noop() {
        let Some(b) = try_init() else {
            return;
        };
        // m = n = k = 0 is an empty product — Ok(()) even with null pointers.
        assert_eq!(
            b.gemm(
                BackendTranspose::NoTrans,
                BackendTranspose::NoTrans,
                0,
                0,
                0,
                1.0,
                0,
                1,
                0,
                1,
                0.0,
                0,
                1
            ),
            Ok(())
        );
    }

    #[test]
    fn unary_zero_n_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.unary(UnaryOp::Relu, 0, 0, 0), Ok(()));
    }

    #[test]
    fn binary_zero_n_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.binary(BinaryOp::Add, 0, 0, 0, 0), Ok(()));
    }

    #[test]
    fn synchronize_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.synchronize(), Ok(()));
    }

    // Error messages below are pinned byte-for-byte against the backend.

    #[test]
    fn reduce_empty_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[], 0),
            Err(BackendError::InvalidArgument(
                "shape must not be empty".into()
            ))
        );
    }

    #[test]
    fn reduce_axis_out_of_bounds_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[4, 4], 5),
            Err(BackendError::InvalidArgument(
                "axis 5 is out of bounds for shape of length 2".into()
            ))
        );
    }

    #[test]
    fn attention_zero_seq_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 0, 8, 64, 0.125, false),
            Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into()
            ))
        );
    }

    #[test]
    fn attention_invalid_scale_error() {
        let Some(b) = try_init() else {
            return;
        };
        // Note: f64 Display renders 0.0 as "0" and -1.0 as "-1".
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, 0.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got 0".into()
            ))
        );
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, -1.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got -1".into()
            ))
        );
        assert!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, f64::INFINITY, false)
                .is_err()
        );
    }

    #[test]
    fn conv2d_wrong_input_rank() {
        let Some(b) = try_init() else {
            return;
        };
        // 3-element input shape (rank 3) must be rejected before anything else.
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32],
                0,
                &[16, 3, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "input_shape must have 4 elements (NCHW)".into()
            ))
        );
    }

    #[test]
    fn conv2d_wrong_filter_rank() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32, 32],
                0,
                &[16, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "filter_shape must have 4 elements (KCFHFW)".into()
            ))
        );
    }

    // ---- Lifecycle & round-trip tests (device required). ----

    #[test]
    fn init_idempotent() {
        let Some(mut b) = try_init() else {
            return;
        };
        // Second init must be a cheap Ok(()) and leave the state intact.
        assert_eq!(b.init(), Ok(()));
        assert!(b.is_initialized());
    }

    #[test]
    fn alloc_copy_roundtrip() {
        let Some(b) = try_init() else {
            return;
        };
        // alloc → htod → dtoh → free: bytes must survive the round trip.
        let src: Vec<u8> = (0u8..64).collect();
        let handle = match b.alloc(src.len()) {
            Ok(h) => h,
            Err(_) => return,
        };
        b.copy_htod(handle, &src).expect("copy_htod");
        let mut dst = vec![0u8; src.len()];
        b.copy_dtoh(&mut dst, handle).expect("copy_dtoh");
        assert_eq!(src, dst);
        b.free(handle).expect("free");
    }

    #[test]
    fn double_init_is_noop() {
        let Some(mut b) = try_init() else {
            return;
        };
        let first = b.is_initialized();
        let _ = b.init();
        assert_eq!(first, b.is_initialized());
    }

    #[test]
    fn alloc_and_free_basic() {
        let Some(b) = try_init() else {
            return;
        };
        match b.alloc(128) {
            Ok(handle) => {
                // Handles are expected to be non-null opaque pointers.
                assert!(handle > 0);
                b.free(handle).expect("free should succeed");
            }
            Err(_) => {
                // Allocation may legitimately fail (e.g. device OOM); not a bug.
            }
        }
    }
}