1use std::sync::Arc;
7
8use oxicuda_backend::{
9 BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
10};
11
12use crate::{device::LevelZeroDevice, memory::LevelZeroMemoryManager};
13
/// [`ComputeBackend`] implementation backed by Intel's oneAPI Level Zero API.
///
/// Constructed empty via [`LevelZeroBackend::new`]; `init` must succeed
/// before any compute or memory entry point will do real work.
#[derive(Debug)]
pub struct LevelZeroBackend {
    // Device wrapper (driver API table, context, queue); set by `init`.
    device: Option<Arc<LevelZeroDevice>>,
    // Allocator and host<->device copy engine; set by `init`, shares the device.
    memory: Option<Arc<LevelZeroMemoryManager>>,
    // True once `init` has completed successfully; gates every entry point.
    initialized: bool,
}
37
38impl LevelZeroBackend {
39 pub fn new() -> Self {
41 Self {
42 device: None,
43 memory: None,
44 initialized: false,
45 }
46 }
47
48 fn check_init(&self) -> BackendResult<()> {
50 if self.initialized {
51 Ok(())
52 } else {
53 Err(BackendError::NotInitialized)
54 }
55 }
56
57 fn memory(&self) -> BackendResult<&Arc<LevelZeroMemoryManager>> {
59 self.memory.as_ref().ok_or(BackendError::NotInitialized)
60 }
61}
62
63impl Default for LevelZeroBackend {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl ComputeBackend for LevelZeroBackend {
    /// Stable identifier for this backend ("level-zero").
    fn name(&self) -> &str {
        "level-zero"
    }
75
76 fn init(&mut self) -> BackendResult<()> {
77 if self.initialized {
78 return Ok(());
79 }
80 match LevelZeroDevice::new() {
81 Ok(dev) => {
82 let dev = Arc::new(dev);
83 tracing::info!("Level Zero backend initialised on: {}", dev.name());
84 let memory = LevelZeroMemoryManager::new(Arc::clone(&dev));
85 self.device = Some(dev);
86 self.memory = Some(Arc::new(memory));
87 self.initialized = true;
88 Ok(())
89 }
90 Err(e) => Err(BackendError::from(e)),
91 }
92 }
93
    /// Reports whether `init` has completed successfully.
    fn is_initialized(&self) -> bool {
        self.initialized
    }
97
    /// Single-precision GEMM dispatched to the SPIR-V kernel.
    ///
    /// `alpha`/`beta` are narrowed from f64 to f32 before dispatch, and a
    /// zero-sized problem (`m`, `n`, or `k` == 0) is a successful no-op.
    ///
    /// NOTE(review): the transpose flags and leading dimensions are accepted
    /// but ignored (`_trans_a`, `_lda`, ...) — the kernel appears to assume
    /// packed, non-transposed inputs. Confirm callers only pass `NoTrans`
    /// with packed strides, or results will be silently wrong.
    fn gemm(
        &self,
        _trans_a: BackendTranspose,
        _trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        _lda: usize,
        b_ptr: u64,
        _ldb: usize,
        beta: f64,
        c_ptr: u64,
        _ldc: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if m == 0 || n == 0 || k == 0 {
            return Ok(());
        }
        self.dispatch_gemm(m, n, k, alpha as f32, a_ptr, b_ptr, beta as f32, c_ptr)
    }
122
123 fn conv2d_forward(
124 &self,
125 input_ptr: u64,
126 input_shape: &[usize],
127 filter_ptr: u64,
128 filter_shape: &[usize],
129 output_ptr: u64,
130 output_shape: &[usize],
131 stride: &[usize],
132 padding: &[usize],
133 ) -> BackendResult<()> {
134 self.check_init()?;
135
136 if input_shape.len() != 4 {
137 return Err(BackendError::InvalidArgument(
138 "input_shape must have 4 elements (NCHW)".into(),
139 ));
140 }
141 if filter_shape.len() != 4 {
142 return Err(BackendError::InvalidArgument(
143 "filter_shape must have 4 elements (KCFHFW)".into(),
144 ));
145 }
146 if output_shape.len() != 4 {
147 return Err(BackendError::InvalidArgument(
148 "output_shape must have 4 elements (NKOhOw)".into(),
149 ));
150 }
151 if stride.len() != 2 {
152 return Err(BackendError::InvalidArgument(
153 "stride must have 2 elements [sh, sw]".into(),
154 ));
155 }
156 if padding.len() != 2 {
157 return Err(BackendError::InvalidArgument(
158 "padding must have 2 elements [ph, pw]".into(),
159 ));
160 }
161
162 let n = input_shape[0];
163 let c_in = input_shape[1];
164 let h_in = input_shape[2];
165 let w_in = input_shape[3];
166 let k_out = filter_shape[0];
167 let fh = filter_shape[2];
168 let fw = filter_shape[3];
169 let o_h = output_shape[2];
170 let o_w = output_shape[3];
171 let stride_h = stride[0];
172 let stride_w = stride[1];
173 let pad_h = padding[0];
174 let pad_w = padding[1];
175
176 let in_len = n * c_in * h_in * w_in;
178 let flt_len = k_out * c_in * fh * fw;
179 let out_len = n * k_out * o_h * o_w;
180
181 let mut in_bytes = vec![0u8; in_len * 4];
182 self.copy_dtoh(&mut in_bytes, input_ptr)?;
183 let inp: Vec<f32> = in_bytes
184 .chunks_exact(4)
185 .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
186 .collect();
187
188 let mut flt_bytes = vec![0u8; flt_len * 4];
189 self.copy_dtoh(&mut flt_bytes, filter_ptr)?;
190 let flt: Vec<f32> = flt_bytes
191 .chunks_exact(4)
192 .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
193 .collect();
194
195 let mut out = vec![0.0f32; out_len];
197 for b_idx in 0..n {
198 for kf in 0..k_out {
199 for oy in 0..o_h {
200 for ox in 0..o_w {
201 let mut acc = 0.0f32;
202 for ci in 0..c_in {
203 for fy in 0..fh {
204 for fx in 0..fw {
205 let iy = (oy * stride_h + fy) as isize - pad_h as isize;
206 let ix = (ox * stride_w + fx) as isize - pad_w as isize;
207 if iy >= 0
208 && (iy as usize) < h_in
209 && ix >= 0
210 && (ix as usize) < w_in
211 {
212 let iy = iy as usize;
213 let ix = ix as usize;
214 acc += inp[((b_idx * c_in + ci) * h_in + iy) * w_in + ix]
215 * flt[((kf * c_in + ci) * fh + fy) * fw + fx];
216 }
217 }
218 }
219 }
220 out[((b_idx * k_out + kf) * o_h + oy) * o_w + ox] = acc;
221 }
222 }
223 }
224 }
225
226 let out_bytes: Vec<u8> = out.iter().flat_map(|f| f.to_ne_bytes()).collect();
227 self.copy_htod(output_ptr, &out_bytes)
228 }
229
230 fn attention(
231 &self,
232 q_ptr: u64,
233 k_ptr: u64,
234 v_ptr: u64,
235 o_ptr: u64,
236 batch: usize,
237 heads: usize,
238 seq_q: usize,
239 seq_kv: usize,
240 head_dim: usize,
241 scale: f64,
242 causal: bool,
243 ) -> BackendResult<()> {
244 self.check_init()?;
245
246 if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
247 return Err(BackendError::InvalidArgument(
248 "seq_q, seq_kv, and head_dim must all be > 0".into(),
249 ));
250 }
251 if scale <= 0.0 || !scale.is_finite() {
252 return Err(BackendError::InvalidArgument(format!(
253 "scale must be a positive finite number, got {scale}"
254 )));
255 }
256
257 let batch_heads = batch * heads;
258 let scale_f32 = scale as f32;
259
260 let q_len = batch_heads * seq_q * head_dim;
262 let kv_len = batch_heads * seq_kv * head_dim;
263
264 let mut q_bytes = vec![0u8; q_len * 4];
265 self.copy_dtoh(&mut q_bytes, q_ptr)?;
266 let q: Vec<f32> = q_bytes
267 .chunks_exact(4)
268 .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
269 .collect();
270
271 let mut k_bytes = vec![0u8; kv_len * 4];
272 self.copy_dtoh(&mut k_bytes, k_ptr)?;
273 let k: Vec<f32> = k_bytes
274 .chunks_exact(4)
275 .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
276 .collect();
277
278 let mut v_bytes = vec![0u8; kv_len * 4];
279 self.copy_dtoh(&mut v_bytes, v_ptr)?;
280 let v: Vec<f32> = v_bytes
281 .chunks_exact(4)
282 .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
283 .collect();
284
285 let mut output = vec![0.0f32; q_len];
287
288 for bh in 0..batch_heads {
289 for sq in 0..seq_q {
290 let q_off = (bh * seq_q + sq) * head_dim;
291 let o_off = q_off;
292
293 let mut max_score = f32::NEG_INFINITY;
295 for sk in 0..seq_kv {
296 if causal && sk > sq {
297 continue;
298 }
299 let k_off = (bh * seq_kv + sk) * head_dim;
300 let mut dot = 0.0f32;
301 for d in 0..head_dim {
302 dot += q[q_off + d] * k[k_off + d];
303 }
304 let score = dot * scale_f32;
305 if score > max_score {
306 max_score = score;
307 }
308 }
309
310 if max_score == f32::NEG_INFINITY {
311 max_score = 0.0;
312 }
313
314 let mut sum_exp = 0.0f32;
316 for sk in 0..seq_kv {
317 if causal && sk > sq {
318 continue;
319 }
320 let k_off = (bh * seq_kv + sk) * head_dim;
321 let v_off = (bh * seq_kv + sk) * head_dim;
322 let mut dot = 0.0f32;
323 for d in 0..head_dim {
324 dot += q[q_off + d] * k[k_off + d];
325 }
326 let w = (dot * scale_f32 - max_score).exp();
327 sum_exp += w;
328 for d in 0..head_dim {
329 output[o_off + d] += w * v[v_off + d];
330 }
331 }
332
333 if sum_exp > 0.0 {
335 for d in 0..head_dim {
336 output[o_off + d] /= sum_exp;
337 }
338 }
339 }
340 }
341
342 let o_bytes: Vec<u8> = output.iter().flat_map(|f| f.to_ne_bytes()).collect();
343 self.copy_htod(o_ptr, &o_bytes)
344 }
345
346 fn reduce(
347 &self,
348 op: ReduceOp,
349 input_ptr: u64,
350 output_ptr: u64,
351 shape: &[usize],
352 axis: usize,
353 ) -> BackendResult<()> {
354 self.check_init()?;
355
356 if shape.is_empty() {
357 return Err(BackendError::InvalidArgument(
358 "shape must not be empty".into(),
359 ));
360 }
361 if axis >= shape.len() {
362 return Err(BackendError::InvalidArgument(format!(
363 "axis {axis} is out of bounds for shape of length {}",
364 shape.len()
365 )));
366 }
367
368 self.dispatch_reduce(op, input_ptr, output_ptr, shape, axis)
369 }
370
371 fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
372 self.check_init()?;
373 if n == 0 {
374 return Ok(());
375 }
376 self.dispatch_unary(op, input_ptr, output_ptr, n)
377 }
378
379 fn binary(
380 &self,
381 op: BinaryOp,
382 a_ptr: u64,
383 b_ptr: u64,
384 output_ptr: u64,
385 n: usize,
386 ) -> BackendResult<()> {
387 self.check_init()?;
388 if n == 0 {
389 return Ok(());
390 }
391 self.dispatch_binary(op, a_ptr, b_ptr, output_ptr, n)
392 }
393
    /// Blocks until all work submitted to the device queue has completed.
    ///
    /// On non-Linux/Windows targets this compiles to a bare `Ok(())` (no
    /// Level Zero runtime exists there).
    fn synchronize(&self) -> BackendResult<()> {
        self.check_init()?;

        #[cfg(any(target_os = "linux", target_os = "windows"))]
        {
            if let Some(dev) = &self.device {
                let api = &dev.api;
                let queue = dev.queue;
                // SAFETY: `queue` is a handle owned by `dev`, which is kept
                // alive by the borrow above; u64::MAX requests an unbounded
                // wait per the zeCommandQueueSynchronize contract.
                let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
                if rc != 0 {
                    return Err(BackendError::DeviceError(format!(
                        "zeCommandQueueSynchronize failed: 0x{rc:08x}"
                    )));
                }
            }
        }

        Ok(())
    }
417
418 fn alloc(&self, bytes: usize) -> BackendResult<u64> {
421 self.check_init()?;
422 if bytes == 0 {
423 return Err(BackendError::InvalidArgument(
424 "cannot allocate 0 bytes".into(),
425 ));
426 }
427 self.memory()?.alloc(bytes).map_err(BackendError::from)
428 }
429
    /// Releases a device buffer previously returned by [`Self::alloc`].
    fn free(&self, ptr: u64) -> BackendResult<()> {
        self.check_init()?;
        self.memory()?.free(ptr).map_err(BackendError::from)
    }
434
435 fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
436 self.check_init()?;
437 if src.is_empty() {
438 return Ok(());
439 }
440 self.memory()?
441 .copy_to_device(dst, src)
442 .map_err(BackendError::from)
443 }
444
445 fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
446 self.check_init()?;
447 if dst.is_empty() {
448 return Ok(());
449 }
450 self.memory()?
451 .copy_from_device(dst, src)
452 .map_err(BackendError::from)
453 }
454}
455
/// Work-group size baked into the generated SPIR-V kernels; every dispatch
/// below launches `ceil(total / WORKGROUP_SIZE)` groups along X.
const WORKGROUP_SIZE: u32 = crate::spirv::WORKGROUP_SIZE;
460
/// One argument bound to a SPIR-V kernel launch in `run_kernel`.
// Unused on platforms without a Level Zero runtime, hence the allow.
#[cfg_attr(not(any(target_os = "linux", target_os = "windows")), allow(dead_code))]
enum KernelArg {
    /// Device buffer handle; resolved to a raw device pointer at bind time.
    Buffer(u64),
    /// 32-bit unsigned scalar kernel argument.
    U32(u32),
    /// 32-bit float scalar kernel argument.
    F32(f32),
}
471
472impl LevelZeroBackend {
    /// Compiles `spv_words` into a Level Zero module, binds `args`, launches
    /// `workgroups` groups of the kernel named `main`, and blocks until the
    /// queue drains. Module, kernel, and command list are destroyed before
    /// returning on every path (success and each error).
    ///
    /// NOTE(review): the SPIR-V module is recompiled on every launch —
    /// caching compiled modules per shader would likely help hot paths;
    /// confirm against profiling before changing.
    ///
    /// On non-Linux/Windows targets this always returns `DeviceError`.
    fn run_kernel(
        &self,
        spv_words: &[u32],
        args: &[KernelArg],
        workgroups: u32,
    ) -> BackendResult<()> {
        #[cfg(any(target_os = "linux", target_os = "windows"))]
        {
            use std::ffi::c_void;

            use crate::device::{
                ZE_MODULE_FORMAT_IL_SPIRV, ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
                ZE_STRUCTURE_TYPE_KERNEL_DESC, ZE_STRUCTURE_TYPE_MODULE_DESC, ZeCommandListDesc,
                ZeGroupCount, ZeKernelDesc, ZeKernelHandle, ZeModuleDesc, ZeModuleHandle,
            };

            let device = self.device.as_ref().ok_or(BackendError::NotInitialized)?;
            let memory = self.memory()?;
            let api = &device.api;
            let context = device.context;
            let dev_handle = device.device;
            let queue = device.queue;

            // SPIR-V words serialised to the byte stream zeModuleCreate expects.
            let spv_bytes: Vec<u8> = spv_words.iter().flat_map(|w| w.to_ne_bytes()).collect();

            // 1. Compile the SPIR-V into a module for this device.
            let module_desc = ZeModuleDesc {
                stype: ZE_STRUCTURE_TYPE_MODULE_DESC,
                p_next: std::ptr::null(),
                format: ZE_MODULE_FORMAT_IL_SPIRV,
                input_size: spv_bytes.len(),
                p_input_module: spv_bytes.as_ptr(),
                p_build_flags: std::ptr::null(),
                p_constants: std::ptr::null(),
            };
            let mut module: ZeModuleHandle = std::ptr::null_mut();
            let rc = unsafe {
                (api.ze_module_create)(
                    context,
                    dev_handle,
                    &module_desc,
                    &mut module as *mut ZeModuleHandle,
                    std::ptr::null_mut(),
                )
            };
            if rc != 0 {
                return Err(BackendError::DeviceError(format!(
                    "zeModuleCreate failed: 0x{rc:08x}"
                )));
            }

            // 2. Look up the single entry point; all generated shaders export "main".
            let kernel_name = b"main\0";
            let kernel_desc = ZeKernelDesc {
                stype: ZE_STRUCTURE_TYPE_KERNEL_DESC,
                p_next: std::ptr::null(),
                flags: 0,
                p_kernel_name: kernel_name.as_ptr(),
            };
            let mut kernel: ZeKernelHandle = std::ptr::null_mut();
            let rc = unsafe {
                (api.ze_kernel_create)(module, &kernel_desc, &mut kernel as *mut ZeKernelHandle)
            };
            if rc != 0 {
                unsafe { (api.ze_module_destroy)(module) };
                return Err(BackendError::DeviceError(format!(
                    "zeKernelCreate failed: 0x{rc:08x}"
                )));
            }

            // 3. 1-D group size matching the shader's compiled local size.
            let rc = unsafe { (api.ze_kernel_set_group_size)(kernel, WORKGROUP_SIZE, 1, 1) };
            if rc != 0 {
                unsafe {
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeKernelSetGroupSize failed: 0x{rc:08x}"
                )));
            }

            // 4. Bind every argument by position: buffers become raw device
            //    pointers, scalars are passed by value.
            for (idx, arg) in args.iter().enumerate() {
                let rc = match arg {
                    KernelArg::Buffer(handle) => {
                        let dev_ptr = memory.device_ptr(*handle).map_err(|e| {
                            unsafe {
                                (api.ze_kernel_destroy)(kernel);
                                (api.ze_module_destroy)(module);
                            }
                            BackendError::from(e)
                        })?;
                        unsafe {
                            (api.ze_kernel_set_argument_value)(
                                kernel,
                                idx as u32,
                                std::mem::size_of::<*mut c_void>(),
                                &dev_ptr as *const *mut c_void as *const c_void,
                            )
                        }
                    }
                    KernelArg::U32(val) => unsafe {
                        (api.ze_kernel_set_argument_value)(
                            kernel,
                            idx as u32,
                            std::mem::size_of::<u32>(),
                            val as *const u32 as *const c_void,
                        )
                    },
                    KernelArg::F32(val) => unsafe {
                        (api.ze_kernel_set_argument_value)(
                            kernel,
                            idx as u32,
                            std::mem::size_of::<f32>(),
                            val as *const f32 as *const c_void,
                        )
                    },
                };
                if rc != 0 {
                    unsafe {
                        (api.ze_kernel_destroy)(kernel);
                        (api.ze_module_destroy)(module);
                    }
                    return Err(BackendError::DeviceError(format!(
                        "zeKernelSetArgumentValue(arg={idx}) failed: 0x{rc:08x}"
                    )));
                }
            }

            // 5. Record the launch into a fresh command list.
            let list_desc = ZeCommandListDesc {
                stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
                p_next: std::ptr::null(),
                command_queue_group_ordinal: 0,
                flags: 0,
            };
            let mut list = std::ptr::null_mut();
            let rc =
                unsafe { (api.ze_command_list_create)(context, dev_handle, &list_desc, &mut list) };
            if rc != 0 {
                unsafe {
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandListCreate failed: 0x{rc:08x}"
                )));
            }

            let group_count = ZeGroupCount {
                group_count_x: workgroups,
                group_count_y: 1,
                group_count_z: 1,
            };
            let rc = unsafe {
                (api.ze_command_list_append_launch_kernel)(
                    list,
                    kernel,
                    &group_count,
                    0,
                    0,
                    std::ptr::null(),
                )
            };
            if rc != 0 {
                unsafe {
                    (api.ze_command_list_destroy)(list);
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandListAppendLaunchKernel failed: 0x{rc:08x}"
                )));
            }

            // 6. Close, execute, and wait for completion synchronously.
            let rc = unsafe { (api.ze_command_list_close)(list) };
            if rc != 0 {
                unsafe {
                    (api.ze_command_list_destroy)(list);
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandListClose failed: 0x{rc:08x}"
                )));
            }

            let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
            if rc != 0 {
                unsafe {
                    (api.ze_command_list_destroy)(list);
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
                )));
            }

            let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
            if rc != 0 {
                unsafe {
                    (api.ze_command_list_destroy)(list);
                    (api.ze_kernel_destroy)(kernel);
                    (api.ze_module_destroy)(module);
                }
                return Err(BackendError::DeviceError(format!(
                    "zeCommandQueueSynchronize failed: 0x{rc:08x}"
                )));
            }

            // 7. Tear down in reverse creation order.
            unsafe {
                (api.ze_command_list_destroy)(list);
                (api.ze_kernel_destroy)(kernel);
                (api.ze_module_destroy)(module);
            }

            Ok(())
        }

        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
        {
            let _ = (spv_words, args, workgroups);
            Err(BackendError::DeviceError(
                "Level Zero requires Linux or Windows".into(),
            ))
        }
    }
713
714 fn dispatch_unary(
715 &self,
716 op: UnaryOp,
717 input_ptr: u64,
718 output_ptr: u64,
719 n: usize,
720 ) -> BackendResult<()> {
721 let spv = crate::spirv::unary_compute_shader(op);
722 let args = [
723 KernelArg::Buffer(input_ptr),
724 KernelArg::Buffer(output_ptr),
725 KernelArg::U32(n as u32),
726 ];
727 self.run_kernel(&spv, &args, (n as u32).div_ceil(WORKGROUP_SIZE))
728 }
729
730 fn dispatch_binary(
731 &self,
732 op: BinaryOp,
733 a_ptr: u64,
734 b_ptr: u64,
735 output_ptr: u64,
736 n: usize,
737 ) -> BackendResult<()> {
738 let spv = crate::spirv::binary_compute_shader(op);
739 let args = [
740 KernelArg::Buffer(a_ptr),
741 KernelArg::Buffer(b_ptr),
742 KernelArg::Buffer(output_ptr),
743 KernelArg::U32(n as u32),
744 ];
745 self.run_kernel(&spv, &args, (n as u32).div_ceil(WORKGROUP_SIZE))
746 }
747
748 fn dispatch_reduce(
749 &self,
750 op: ReduceOp,
751 input_ptr: u64,
752 output_ptr: u64,
753 shape: &[usize],
754 axis: usize,
755 ) -> BackendResult<()> {
756 let outer_size: usize = shape[..axis].iter().product::<usize>().max(1);
757 let reduce_size = shape[axis];
758 let inner_size: usize = shape[axis + 1..].iter().product::<usize>().max(1);
759
760 let spv = crate::spirv::reduce_compute_shader(op);
761 let total_output = (outer_size * inner_size) as u32;
762 let args = [
763 KernelArg::Buffer(input_ptr),
764 KernelArg::Buffer(output_ptr),
765 KernelArg::U32(outer_size as u32),
766 KernelArg::U32(reduce_size as u32),
767 KernelArg::U32(inner_size as u32),
768 ];
769 self.run_kernel(&spv, &args, total_output.div_ceil(WORKGROUP_SIZE))
770 }
771
772 #[allow(clippy::too_many_arguments)]
773 fn dispatch_gemm(
774 &self,
775 m: usize,
776 n: usize,
777 k: usize,
778 alpha: f32,
779 a_ptr: u64,
780 b_ptr: u64,
781 beta: f32,
782 c_ptr: u64,
783 ) -> BackendResult<()> {
784 let spv = crate::spirv::gemm_compute_shader();
785 let total = (m * n) as u32;
786 let args = [
787 KernelArg::Buffer(a_ptr),
788 KernelArg::Buffer(b_ptr),
789 KernelArg::Buffer(c_ptr),
790 KernelArg::U32(m as u32),
791 KernelArg::U32(n as u32),
792 KernelArg::U32(k as u32),
793 KernelArg::F32(alpha),
794 KernelArg::F32(beta),
795 ];
796 self.run_kernel(&spv, &args, total.div_ceil(WORKGROUP_SIZE))
797 }
798}
799
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_backend::{BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp};

    // --- Construction & metadata: no physical device required ---

    #[test]
    fn level_zero_backend_new_uninitialized() {
        let b = LevelZeroBackend::new();
        assert!(!b.is_initialized());
    }

    #[test]
    fn level_zero_backend_name() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.name(), "level-zero");
    }

    #[test]
    fn level_zero_backend_default() {
        let b = LevelZeroBackend::default();
        assert!(!b.is_initialized());
        assert_eq!(b.name(), "level-zero");
    }

    #[test]
    fn backend_debug_impl() {
        let b = LevelZeroBackend::new();
        let s = format!("{b:?}");
        assert!(s.contains("LevelZeroBackend"));
    }

    #[test]
    fn backend_object_safe() {
        let b: Box<dyn ComputeBackend> = Box::new(LevelZeroBackend::new());
        assert_eq!(b.name(), "level-zero");
    }

    // --- Every entry point must refuse to run before init ---

    #[test]
    fn backend_not_initialized_gemm() {
        let b = LevelZeroBackend::new();
        let result = b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            4,
            4,
            1.0,
            0,
            4,
            0,
            4,
            0.0,
            0,
            4,
        );
        assert_eq!(result, Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_alloc() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.alloc(1024), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_synchronize() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.synchronize(), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_free() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.free(1), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_htod() {
        let b = LevelZeroBackend::new();
        assert_eq!(b.copy_htod(1, b"hello"), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_dtoh() {
        let b = LevelZeroBackend::new();
        let mut buf = [0u8; 4];
        assert_eq!(b.copy_dtoh(&mut buf, 1), Err(BackendError::NotInitialized));
    }

    // Returns an initialised backend, or None on machines without a Level
    // Zero device/driver — device-dependent tests below skip in that case.
    fn try_init() -> Option<LevelZeroBackend> {
        let mut b = LevelZeroBackend::new();
        match b.init() {
            Ok(()) => Some(b),
            Err(_) => None,
        }
    }

    #[test]
    fn init_graceful_failure() {
        let mut b = LevelZeroBackend::new();
        // Must not panic regardless of whether a device is present.
        let _result = b.init();
    }

    // --- Argument validation & no-op paths (device required) ---

    #[test]
    fn alloc_zero_bytes_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.alloc(0),
            Err(BackendError::InvalidArgument(
                "cannot allocate 0 bytes".into()
            ))
        );
    }

    #[test]
    fn copy_htod_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.copy_htod(0, &[]), Ok(()));
    }

    #[test]
    fn copy_dtoh_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.copy_dtoh(&mut [], 0), Ok(()));
    }

    #[test]
    fn gemm_zero_dims_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.gemm(
                BackendTranspose::NoTrans,
                BackendTranspose::NoTrans,
                0,
                0,
                0,
                1.0,
                0,
                1,
                0,
                1,
                0.0,
                0,
                1
            ),
            Ok(())
        );
    }

    #[test]
    fn unary_zero_n_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.unary(UnaryOp::Relu, 0, 0, 0), Ok(()));
    }

    #[test]
    fn binary_zero_n_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.binary(BinaryOp::Add, 0, 0, 0, 0), Ok(()));
    }

    #[test]
    fn synchronize_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.synchronize(), Ok(()));
    }

    #[test]
    fn reduce_empty_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[], 0),
            Err(BackendError::InvalidArgument(
                "shape must not be empty".into()
            ))
        );
    }

    #[test]
    fn reduce_axis_out_of_bounds_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[4, 4], 5),
            Err(BackendError::InvalidArgument(
                "axis 5 is out of bounds for shape of length 2".into()
            ))
        );
    }

    #[test]
    fn attention_zero_seq_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 0, 8, 64, 0.125, false),
            Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into()
            ))
        );
    }

    #[test]
    fn attention_invalid_scale_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, 0.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got 0".into()
            ))
        );
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, -1.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got -1".into()
            ))
        );
        assert!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, f64::INFINITY, false)
                .is_err()
        );
    }

    #[test]
    fn conv2d_wrong_input_rank() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32],
                0,
                &[16, 3, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "input_shape must have 4 elements (NCHW)".into()
            ))
        );
    }

    #[test]
    fn conv2d_wrong_filter_rank() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32, 32],
                0,
                &[16, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "filter_shape must have 4 elements (KCFHFW)".into()
            ))
        );
    }

    // --- Lifecycle & memory round-trips (device required) ---

    #[test]
    fn init_idempotent() {
        let Some(mut b) = try_init() else {
            return;
        };
        assert_eq!(b.init(), Ok(()));
        assert!(b.is_initialized());
    }

    #[test]
    fn alloc_copy_roundtrip() {
        let Some(b) = try_init() else {
            return;
        };
        let src: Vec<u8> = (0u8..64).collect();
        let handle = match b.alloc(src.len()) {
            Ok(h) => h,
            Err(_) => return,
        };
        b.copy_htod(handle, &src).expect("copy_htod");
        let mut dst = vec![0u8; src.len()];
        b.copy_dtoh(&mut dst, handle).expect("copy_dtoh");
        assert_eq!(src, dst);
        b.free(handle).expect("free");
    }

    #[test]
    fn double_init_is_noop() {
        let Some(mut b) = try_init() else {
            return;
        };
        let first = b.is_initialized();
        let _ = b.init();
        assert_eq!(first, b.is_initialized());
    }

    #[test]
    fn alloc_and_free_basic() {
        let Some(b) = try_init() else {
            return;
        };
        match b.alloc(128) {
            Ok(handle) => {
                assert!(handle > 0);
                b.free(handle).expect("free should succeed");
            }
            Err(_) => {
                // Allocation may fail on constrained devices; not a test failure.
            }
        }
    }

    // Uploads `data` as native-endian f32 bytes; None if alloc/copy fails.
    fn upload_f32(b: &LevelZeroBackend, data: &[f32]) -> Option<u64> {
        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_ne_bytes()).collect();
        let handle = b.alloc(bytes.len()).ok()?;
        b.copy_htod(handle, &bytes).ok()?;
        Some(handle)
    }

    // Downloads `len` f32 values from a device handle; None on copy failure.
    fn download_f32(b: &LevelZeroBackend, handle: u64, len: usize) -> Option<Vec<f32>> {
        let mut bytes = vec![0u8; len * 4];
        b.copy_dtoh(&mut bytes, handle).ok()?;
        Some(
            bytes
                .chunks_exact(4)
                .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                .collect(),
        )
    }

    // --- Numeric correctness of the CPU fallbacks (device required) ---

    #[test]
    fn l0_conv2d_identity_1x1() {
        let Some(b) = try_init() else {
            return;
        };
        let input: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let filter = vec![1.0f32];
        let output_len = 16;

        let Some(in_h) = upload_f32(&b, &input) else {
            return;
        };
        let Some(flt_h) = upload_f32(&b, &filter) else {
            return;
        };
        let Some(out_h) = b.alloc(output_len * 4).ok() else {
            return;
        };

        let result = b.conv2d_forward(
            in_h,
            &[1, 1, 4, 4],
            flt_h,
            &[1, 1, 1, 1],
            out_h,
            &[1, 1, 4, 4],
            &[1, 1],
            &[0, 0],
        );
        assert!(result.is_ok(), "conv2d_forward failed: {result:?}");

        // 1x1 identity filter: output must equal input elementwise.
        if let Some(out) = download_f32(&b, out_h, output_len) {
            for (i, &val) in out.iter().enumerate() {
                assert!(
                    (val - input[i]).abs() < 1e-5,
                    "mismatch at {i}: expected {}, got {val}",
                    input[i]
                );
            }
        }

        let _ = b.free(in_h);
        let _ = b.free(flt_h);
        let _ = b.free(out_h);
    }

    #[test]
    fn l0_conv2d_3x3_basic() {
        let Some(b) = try_init() else {
            return;
        };
        let input: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let filter = vec![1.0f32; 9];
        let output_len = 4;

        let Some(in_h) = upload_f32(&b, &input) else {
            return;
        };
        let Some(flt_h) = upload_f32(&b, &filter) else {
            return;
        };
        let Some(out_h) = b.alloc(output_len * 4).ok() else {
            return;
        };

        let result = b.conv2d_forward(
            in_h,
            &[1, 1, 4, 4],
            flt_h,
            &[1, 1, 3, 3],
            out_h,
            &[1, 1, 2, 2],
            &[1, 1],
            &[0, 0],
        );
        assert!(result.is_ok());

        // All-ones 3x3 filter = sum of each 3x3 window of the 4x4 ramp.
        let expected = [45.0f32, 54.0, 81.0, 90.0];
        if let Some(out) = download_f32(&b, out_h, output_len) {
            for (i, &val) in out.iter().enumerate() {
                assert!(
                    (val - expected[i]).abs() < 1e-4,
                    "mismatch at {i}: expected {}, got {val}",
                    expected[i]
                );
            }
        }

        let _ = b.free(in_h);
        let _ = b.free(flt_h);
        let _ = b.free(out_h);
    }

    #[test]
    fn l0_conv2d_with_padding() {
        let Some(b) = try_init() else {
            return;
        };
        let input: Vec<f32> = (1..=9).map(|i| i as f32).collect();
        let filter = vec![1.0f32; 9];
        let output_len = 9;

        let Some(in_h) = upload_f32(&b, &input) else {
            return;
        };
        let Some(flt_h) = upload_f32(&b, &filter) else {
            return;
        };
        let Some(out_h) = b.alloc(output_len * 4).ok() else {
            return;
        };

        let result = b.conv2d_forward(
            in_h,
            &[1, 1, 3, 3],
            flt_h,
            &[1, 1, 3, 3],
            out_h,
            &[1, 1, 3, 3],
            &[1, 1],
            &[1, 1],
        );
        assert!(result.is_ok());

        if let Some(out) = download_f32(&b, out_h, output_len) {
            // Center tap sees the whole 3x3 input: 1+2+...+9 = 45.
            assert!(
                (out[4] - 45.0).abs() < 1e-4,
                "center expected 45, got {}",
                out[4]
            );
            // Top-left corner sees only the 2x2 window: 1+2+4+5 = 12.
            assert!(
                (out[0] - 12.0).abs() < 1e-4,
                "corner expected 12, got {}",
                out[0]
            );
        }

        let _ = b.free(in_h);
        let _ = b.free(flt_h);
        let _ = b.free(out_h);
    }

    #[test]
    fn l0_attention_uniform() {
        let Some(b) = try_init() else {
            return;
        };
        let seq_q = 2;
        let seq_kv = 2;
        let head_dim = 2;
        let q = vec![0.0f32; seq_q * head_dim];
        let k = vec![0.0f32; seq_kv * head_dim];
        // V rows: [1, 2] and [3, 4].
        let v = vec![1.0f32, 2.0, 3.0, 4.0];
        let o_len = seq_q * head_dim;

        let Some(q_h) = upload_f32(&b, &q) else {
            return;
        };
        let Some(k_h) = upload_f32(&b, &k) else {
            return;
        };
        let Some(v_h) = upload_f32(&b, &v) else {
            return;
        };
        let Some(o_h) = b.alloc(o_len * 4).ok() else {
            return;
        };
        let zeros = vec![0u8; o_len * 4];
        let _ = b.copy_htod(o_h, &zeros);

        let scale = 1.0 / (head_dim as f64).sqrt();
        let result = b.attention(
            q_h, k_h, v_h, o_h, 1, 1, seq_q, seq_kv, head_dim, scale, false,
        );
        assert!(result.is_ok(), "attention failed: {result:?}");

        // Zero Q/K -> uniform softmax -> every output row is the mean of V.
        if let Some(out) = download_f32(&b, o_h, o_len) {
            for sq_idx in 0..seq_q {
                let off = sq_idx * head_dim;
                assert!(
                    (out[off] - 2.0).abs() < 1e-4,
                    "q{sq_idx}[0] expected 2.0, got {}",
                    out[off]
                );
                assert!(
                    (out[off + 1] - 3.0).abs() < 1e-4,
                    "q{sq_idx}[1] expected 3.0, got {}",
                    out[off + 1]
                );
            }
        }

        let _ = b.free(q_h);
        let _ = b.free(k_h);
        let _ = b.free(v_h);
        let _ = b.free(o_h);
    }

    #[test]
    fn l0_attention_causal() {
        let Some(b) = try_init() else {
            return;
        };
        let seq_q = 2;
        let seq_kv = 2;
        let head_dim = 2;
        let q = vec![0.0f32; seq_q * head_dim];
        let k = vec![0.0f32; seq_kv * head_dim];
        let v = vec![1.0f32, 2.0, 3.0, 4.0];
        let o_len = seq_q * head_dim;

        let Some(q_h) = upload_f32(&b, &q) else {
            return;
        };
        let Some(k_h) = upload_f32(&b, &k) else {
            return;
        };
        let Some(v_h) = upload_f32(&b, &v) else {
            return;
        };
        let Some(o_h) = b.alloc(o_len * 4).ok() else {
            return;
        };
        let zeros = vec![0u8; o_len * 4];
        let _ = b.copy_htod(o_h, &zeros);

        let scale = 1.0 / (head_dim as f64).sqrt();
        let result = b.attention(
            q_h, k_h, v_h, o_h, 1, 1, seq_q, seq_kv, head_dim, scale, true,
        );
        assert!(result.is_ok());

        if let Some(out) = download_f32(&b, o_h, o_len) {
            // Query 0 may only attend to key 0 -> output is V row 0.
            assert!(
                (out[0] - 1.0).abs() < 1e-4,
                "q0[0] expected 1.0, got {}",
                out[0]
            );
            assert!(
                (out[1] - 2.0).abs() < 1e-4,
                "q0[1] expected 2.0, got {}",
                out[1]
            );
            // Query 1 attends to both keys uniformly -> mean of V rows.
            assert!(
                (out[2] - 2.0).abs() < 1e-4,
                "q1[0] expected 2.0, got {}",
                out[2]
            );
            assert!(
                (out[3] - 3.0).abs() < 1e-4,
                "q1[1] expected 3.0, got {}",
                out[3]
            );
        }

        let _ = b.free(q_h);
        let _ = b.free(k_h);
        let _ = b.free(v_h);
        let _ = b.free(o_h);
    }

    #[test]
    fn l0_attention_dominant_key() {
        let Some(b) = try_init() else {
            return;
        };
        let seq_q = 1;
        let seq_kv = 3;
        let head_dim = 2;
        let q = vec![10.0f32, 0.0];
        let k = vec![10.0f32, 0.0, 0.0, 0.0, 0.0, 0.0];
        // V rows: [1, 0], [0, 1], [0, 0]; key 0 dominates the softmax.
        let v = vec![1.0f32, 0.0, 0.0, 1.0, 0.0, 0.0];
        let o_len = seq_q * head_dim;

        let Some(q_h) = upload_f32(&b, &q) else {
            return;
        };
        let Some(k_h) = upload_f32(&b, &k) else {
            return;
        };
        let Some(v_h) = upload_f32(&b, &v) else {
            return;
        };
        let Some(o_h) = b.alloc(o_len * 4).ok() else {
            return;
        };
        let zeros = vec![0u8; o_len * 4];
        let _ = b.copy_htod(o_h, &zeros);

        let scale = 1.0;
        let result = b.attention(
            q_h, k_h, v_h, o_h, 1, 1, seq_q, seq_kv, head_dim, scale, false,
        );
        assert!(result.is_ok());

        if let Some(out) = download_f32(&b, o_h, o_len) {
            assert!(out[0] > 0.99, "expected output[0] ≈ 1.0, got {}", out[0]);
            assert!(out[1] < 0.01, "expected output[1] ≈ 0.0, got {}", out[1]);
        }

        let _ = b.free(q_h);
        let _ = b.free(k_h);
        let _ = b.free(v_h);
        let _ = b.free(o_h);
    }
}