1use std::sync::Arc;
7
8use oxicuda_backend::{
9 BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
10};
11
12use crate::{device::LevelZeroDevice, memory::LevelZeroMemoryManager};
13
/// [`ComputeBackend`] backed by the Level Zero driver API.
///
/// Constructed empty via [`LevelZeroBackend::new`]; `init` opens the device,
/// builds the memory manager, and flips `initialized`. Every other trait
/// method is gated on that flag.
#[derive(Debug)]
pub struct LevelZeroBackend {
    // Shared handle to the opened Level Zero device; `None` until `init`.
    device: Option<Arc<LevelZeroDevice>>,
    // Allocator / host<->device copy helper bound to `device`; `None` until `init`.
    memory: Option<Arc<LevelZeroMemoryManager>>,
    // Set exactly once, when `init` succeeds.
    initialized: bool,
}
37
38impl LevelZeroBackend {
39 pub fn new() -> Self {
41 Self {
42 device: None,
43 memory: None,
44 initialized: false,
45 }
46 }
47
48 fn check_init(&self) -> BackendResult<()> {
50 if self.initialized {
51 Ok(())
52 } else {
53 Err(BackendError::NotInitialized)
54 }
55 }
56
57 fn memory(&self) -> BackendResult<&Arc<LevelZeroMemoryManager>> {
59 self.memory.as_ref().ok_or(BackendError::NotInitialized)
60 }
61}
62
impl Default for LevelZeroBackend {
    /// Same as [`LevelZeroBackend::new`]: an uninitialised backend.
    fn default() -> Self {
        Self::new()
    }
}
68
/// [`ComputeBackend`] implementation that lowers work onto a Level Zero device.
///
/// GEMM, batched GEMM, elementwise unary/binary ops and reductions run as
/// SPIR-V compute kernels (see the `dispatch_*` helpers). `conv2d_forward`
/// and `attention` are host-side f32 reference implementations that
/// round-trip data through `copy_dtoh` / `copy_htod`.
impl ComputeBackend for LevelZeroBackend {
    fn name(&self) -> &str {
        "level-zero"
    }

    /// Opens a Level Zero device and builds the memory manager for it.
    /// Idempotent: calling `init` on an already-initialised backend is a no-op.
    fn init(&mut self) -> BackendResult<()> {
        if self.initialized {
            return Ok(());
        }
        match LevelZeroDevice::new() {
            Ok(dev) => {
                let dev = Arc::new(dev);
                tracing::info!("Level Zero backend initialised on: {}", dev.name());
                let memory = LevelZeroMemoryManager::new(Arc::clone(&dev));
                self.device = Some(dev);
                self.memory = Some(Arc::new(memory));
                self.initialized = true;
                Ok(())
            }
            Err(e) => Err(BackendError::from(e)),
        }
    }

    fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Single-precision GEMM dispatched as a SPIR-V kernel.
    ///
    /// NOTE(review): the transpose flags and the leading dimensions
    /// (`_lda`/`_ldb`/`_ldc`) are currently ignored — operands are treated
    /// as dense, NoTrans matrices. `alpha`/`beta` are narrowed f64 -> f32
    /// before dispatch. Degenerate dimensions are a successful no-op.
    fn gemm(
        &self,
        _trans_a: BackendTranspose,
        _trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        _lda: usize,
        b_ptr: u64,
        _ldb: usize,
        beta: f64,
        c_ptr: u64,
        _ldc: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if m == 0 || n == 0 || k == 0 {
            return Ok(());
        }
        self.dispatch_gemm(m, n, k, alpha as f32, a_ptr, b_ptr, beta as f32, c_ptr)
    }

    /// Strided batched GEMM dispatched as a single 3-D kernel launch
    /// (batch index on grid Z).
    ///
    /// NOTE(review): as with `gemm`, transpose flags and leading dimensions
    /// are ignored; strides are element strides forwarded to the kernel.
    #[allow(clippy::too_many_arguments)]
    fn batched_gemm(
        &self,
        _trans_a: BackendTranspose,
        _trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        _lda: usize,
        stride_a: usize,
        b_ptr: u64,
        _ldb: usize,
        stride_b: usize,
        beta: f64,
        c_ptr: u64,
        _ldc: usize,
        stride_c: usize,
        batch_count: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if batch_count == 0 || m == 0 || n == 0 || k == 0 {
            return Ok(());
        }
        self.dispatch_batched_gemm(
            m,
            n,
            k,
            alpha as f32,
            a_ptr,
            b_ptr,
            beta as f32,
            c_ptr,
            batch_count,
            stride_a,
            stride_b,
            stride_c,
        )
    }

    /// Naive host-side (CPU) 2-D convolution, NCHW layout, f32 data.
    ///
    /// Input/filter are copied device -> host, the direct 7-deep loop nest
    /// computes the output, and the result is copied back. Shape *ranks* are
    /// validated; NOTE(review): cross-shape consistency (filter in-channels
    /// vs input channels, declared output H/W vs stride/padding arithmetic)
    /// is taken on trust — mismatches produce wrong results, not errors.
    fn conv2d_forward(
        &self,
        input_ptr: u64,
        input_shape: &[usize],
        filter_ptr: u64,
        filter_shape: &[usize],
        output_ptr: u64,
        output_shape: &[usize],
        stride: &[usize],
        padding: &[usize],
    ) -> BackendResult<()> {
        self.check_init()?;

        if input_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "input_shape must have 4 elements (NCHW)".into(),
            ));
        }
        if filter_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "filter_shape must have 4 elements (KCFHFW)".into(),
            ));
        }
        if output_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "output_shape must have 4 elements (NKOhOw)".into(),
            ));
        }
        if stride.len() != 2 {
            return Err(BackendError::InvalidArgument(
                "stride must have 2 elements [sh, sw]".into(),
            ));
        }
        if padding.len() != 2 {
            return Err(BackendError::InvalidArgument(
                "padding must have 2 elements [ph, pw]".into(),
            ));
        }

        let n = input_shape[0];
        let c_in = input_shape[1];
        let h_in = input_shape[2];
        let w_in = input_shape[3];
        let k_out = filter_shape[0];
        let fh = filter_shape[2];
        let fw = filter_shape[3];
        let o_h = output_shape[2];
        let o_w = output_shape[3];
        let stride_h = stride[0];
        let stride_w = stride[1];
        let pad_h = padding[0];
        let pad_w = padding[1];

        let in_len = n * c_in * h_in * w_in;
        let flt_len = k_out * c_in * fh * fw;
        let out_len = n * k_out * o_h * o_w;

        // Pull input and filter back to the host as f32 (4 bytes/element).
        let mut in_bytes = vec![0u8; in_len * 4];
        self.copy_dtoh(&mut in_bytes, input_ptr)?;
        let inp: Vec<f32> = in_bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        let mut flt_bytes = vec![0u8; flt_len * 4];
        self.copy_dtoh(&mut flt_bytes, filter_ptr)?;
        let flt: Vec<f32> = flt_bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        let mut out = vec![0.0f32; out_len];
        // Direct convolution: for each output element, accumulate over
        // input channels and the filter window.
        for b_idx in 0..n {
            for kf in 0..k_out {
                for oy in 0..o_h {
                    for ox in 0..o_w {
                        let mut acc = 0.0f32;
                        for ci in 0..c_in {
                            for fy in 0..fh {
                                for fx in 0..fw {
                                    // Map output coords back into (possibly
                                    // padded) input space; signed so padding
                                    // can go negative.
                                    let iy = (oy * stride_h + fy) as isize - pad_h as isize;
                                    let ix = (ox * stride_w + fx) as isize - pad_w as isize;
                                    // Taps landing in the padding region
                                    // contribute implicit zeros (skipped).
                                    if iy >= 0
                                        && (iy as usize) < h_in
                                        && ix >= 0
                                        && (ix as usize) < w_in
                                    {
                                        let iy = iy as usize;
                                        let ix = ix as usize;
                                        acc += inp[((b_idx * c_in + ci) * h_in + iy) * w_in + ix]
                                            * flt[((kf * c_in + ci) * fh + fy) * fw + fx];
                                    }
                                }
                            }
                        }
                        out[((b_idx * k_out + kf) * o_h + oy) * o_w + ox] = acc;
                    }
                }
            }
        }

        let out_bytes: Vec<u8> = out.iter().flat_map(|f| f.to_ne_bytes()).collect();
        self.copy_htod(output_ptr, &out_bytes)
    }

    /// Host-side (CPU) scaled dot-product attention, f32, with optional
    /// causal masking (`sk > sq` skipped).
    ///
    /// Uses a numerically stable two-pass softmax per query row: pass 1
    /// finds the row max, pass 2 accumulates exp(score - max)-weighted
    /// values. The q·k dot products are computed twice on purpose, to avoid
    /// materialising the full seq_q x seq_kv score matrix.
    fn attention(
        &self,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        o_ptr: u64,
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_kv: usize,
        head_dim: usize,
        scale: f64,
        causal: bool,
    ) -> BackendResult<()> {
        self.check_init()?;

        if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
            return Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into(),
            ));
        }
        if scale <= 0.0 || !scale.is_finite() {
            return Err(BackendError::InvalidArgument(format!(
                "scale must be a positive finite number, got {scale}"
            )));
        }

        let batch_heads = batch * heads;
        let scale_f32 = scale as f32;

        // Q/O are [batch*heads, seq_q, head_dim]; K/V are
        // [batch*heads, seq_kv, head_dim], flattened.
        let q_len = batch_heads * seq_q * head_dim;
        let kv_len = batch_heads * seq_kv * head_dim;

        let mut q_bytes = vec![0u8; q_len * 4];
        self.copy_dtoh(&mut q_bytes, q_ptr)?;
        let q: Vec<f32> = q_bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        let mut k_bytes = vec![0u8; kv_len * 4];
        self.copy_dtoh(&mut k_bytes, k_ptr)?;
        let k: Vec<f32> = k_bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        let mut v_bytes = vec![0u8; kv_len * 4];
        self.copy_dtoh(&mut v_bytes, v_ptr)?;
        let v: Vec<f32> = v_bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        let mut output = vec![0.0f32; q_len];

        for bh in 0..batch_heads {
            for sq in 0..seq_q {
                let q_off = (bh * seq_q + sq) * head_dim;
                let o_off = q_off;

                // Pass 1: row maximum of the scaled scores (for stability).
                let mut max_score = f32::NEG_INFINITY;
                for sk in 0..seq_kv {
                    if causal && sk > sq {
                        continue;
                    }
                    let k_off = (bh * seq_kv + sk) * head_dim;
                    let mut dot = 0.0f32;
                    for d in 0..head_dim {
                        dot += q[q_off + d] * k[k_off + d];
                    }
                    let score = dot * scale_f32;
                    if score > max_score {
                        max_score = score;
                    }
                }

                // Guard: if every key was masked out, fall back to 0 so the
                // exp() below stays finite.
                if max_score == f32::NEG_INFINITY {
                    max_score = 0.0;
                }

                // Pass 2: accumulate exp-weighted values and the softmax
                // denominator.
                let mut sum_exp = 0.0f32;
                for sk in 0..seq_kv {
                    if causal && sk > sq {
                        continue;
                    }
                    let k_off = (bh * seq_kv + sk) * head_dim;
                    let v_off = (bh * seq_kv + sk) * head_dim;
                    let mut dot = 0.0f32;
                    for d in 0..head_dim {
                        dot += q[q_off + d] * k[k_off + d];
                    }
                    let w = (dot * scale_f32 - max_score).exp();
                    sum_exp += w;
                    for d in 0..head_dim {
                        output[o_off + d] += w * v[v_off + d];
                    }
                }

                // Skip normalisation if every weight underflowed to zero.
                if sum_exp > 0.0 {
                    for d in 0..head_dim {
                        output[o_off + d] /= sum_exp;
                    }
                }
            }
        }

        let o_bytes: Vec<u8> = output.iter().flat_map(|f| f.to_ne_bytes()).collect();
        self.copy_htod(o_ptr, &o_bytes)
    }

    /// Reduction along `axis` of `shape`, dispatched as a SPIR-V kernel
    /// after validating that the axis is in bounds.
    fn reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()> {
        self.check_init()?;

        if shape.is_empty() {
            return Err(BackendError::InvalidArgument(
                "shape must not be empty".into(),
            ));
        }
        if axis >= shape.len() {
            return Err(BackendError::InvalidArgument(format!(
                "axis {axis} is out of bounds for shape of length {}",
                shape.len()
            )));
        }

        self.dispatch_reduce(op, input_ptr, output_ptr, shape, axis)
    }

    /// Elementwise unary op over `n` elements; `n == 0` is a no-op.
    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }
        self.dispatch_unary(op, input_ptr, output_ptr, n)
    }

    /// Elementwise binary op over `n` elements; `n == 0` is a no-op.
    fn binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }
        self.dispatch_binary(op, a_ptr, b_ptr, output_ptr, n)
    }

    /// Blocks until the device command queue has drained.
    ///
    /// On targets without Level Zero support (neither Linux nor Windows)
    /// this is a successful no-op after the init check.
    fn synchronize(&self) -> BackendResult<()> {
        self.check_init()?;

        #[cfg(any(target_os = "linux", target_os = "windows"))]
        {
            if let Some(dev) = &self.device {
                let api = &dev.api;
                let queue = dev.queue;
                // u64::MAX timeout = wait indefinitely.
                let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
                if rc != 0 {
                    return Err(BackendError::DeviceError(format!(
                        "zeCommandQueueSynchronize failed: 0x{rc:08x}"
                    )));
                }
            }
        }

        Ok(())
    }

    /// Allocates `bytes` of device memory, returning an opaque handle.
    /// Zero-byte allocations are rejected.
    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
        self.check_init()?;
        if bytes == 0 {
            return Err(BackendError::InvalidArgument(
                "cannot allocate 0 bytes".into(),
            ));
        }
        self.memory()?.alloc(bytes).map_err(BackendError::from)
    }

    /// Releases a handle previously returned by `alloc`.
    fn free(&self, ptr: u64) -> BackendResult<()> {
        self.check_init()?;
        self.memory()?.free(ptr).map_err(BackendError::from)
    }

    /// Host -> device copy; an empty source slice is a no-op.
    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
        self.check_init()?;
        if src.is_empty() {
            return Ok(());
        }
        self.memory()?
            .copy_to_device(dst, src)
            .map_err(BackendError::from)
    }

    /// Device -> host copy; an empty destination slice is a no-op.
    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
        self.check_init()?;
        if dst.is_empty() {
            return Ok(());
        }
        self.memory()?
            .copy_from_device(dst, src)
            .map_err(BackendError::from)
    }
}
496
// Threads per workgroup for kernel launches; re-exported from crate::spirv so
// it always matches the local size baked into the generated shaders.
const WORKGROUP_SIZE: u32 = crate::spirv::WORKGROUP_SIZE;
501
/// One argument bound to a compute kernel before launch.
///
/// On targets without Level Zero support the launch path is compiled out,
/// so the variants would otherwise trigger dead-code warnings.
#[cfg_attr(not(any(target_os = "linux", target_os = "windows")), allow(dead_code))]
enum KernelArg {
    /// Backend memory handle; resolved to a raw device pointer at bind time.
    Buffer(u64),
    /// 32-bit unsigned scalar (sizes, strides, counts).
    U32(u32),
    /// 32-bit float scalar (e.g. alpha/beta coefficients).
    F32(f32),
}
512
513impl LevelZeroBackend {
514 fn run_kernel(
522 &self,
523 spv_words: &[u32],
524 args: &[KernelArg],
525 workgroups: u32,
526 ) -> BackendResult<()> {
527 #[cfg(any(target_os = "linux", target_os = "windows"))]
528 {
529 use std::ffi::c_void;
530
531 use crate::device::{
532 ZE_MODULE_FORMAT_IL_SPIRV, ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
533 ZE_STRUCTURE_TYPE_KERNEL_DESC, ZE_STRUCTURE_TYPE_MODULE_DESC, ZeCommandListDesc,
534 ZeGroupCount, ZeKernelDesc, ZeKernelHandle, ZeModuleDesc, ZeModuleHandle,
535 };
536
537 let device = self.device.as_ref().ok_or(BackendError::NotInitialized)?;
538 let memory = self.memory()?;
539 let api = &device.api;
540 let context = device.context;
541 let dev_handle = device.device;
542 let queue = device.queue;
543
544 let spv_bytes: Vec<u8> = spv_words.iter().flat_map(|w| w.to_ne_bytes()).collect();
546
547 let module_desc = ZeModuleDesc {
549 stype: ZE_STRUCTURE_TYPE_MODULE_DESC,
550 p_next: std::ptr::null(),
551 format: ZE_MODULE_FORMAT_IL_SPIRV,
552 input_size: spv_bytes.len(),
553 p_input_module: spv_bytes.as_ptr(),
554 p_build_flags: std::ptr::null(),
555 p_constants: std::ptr::null(),
556 };
557 let mut module: ZeModuleHandle = std::ptr::null_mut();
558 let rc = unsafe {
559 (api.ze_module_create)(
560 context,
561 dev_handle,
562 &module_desc,
563 &mut module as *mut ZeModuleHandle,
564 std::ptr::null_mut(),
565 )
566 };
567 if rc != 0 {
568 return Err(BackendError::DeviceError(format!(
569 "zeModuleCreate failed: 0x{rc:08x}"
570 )));
571 }
572
573 let kernel_name = b"main\0";
575 let kernel_desc = ZeKernelDesc {
576 stype: ZE_STRUCTURE_TYPE_KERNEL_DESC,
577 p_next: std::ptr::null(),
578 flags: 0,
579 p_kernel_name: kernel_name.as_ptr(),
580 };
581 let mut kernel: ZeKernelHandle = std::ptr::null_mut();
582 let rc = unsafe {
583 (api.ze_kernel_create)(module, &kernel_desc, &mut kernel as *mut ZeKernelHandle)
584 };
585 if rc != 0 {
586 unsafe { (api.ze_module_destroy)(module) };
587 return Err(BackendError::DeviceError(format!(
588 "zeKernelCreate failed: 0x{rc:08x}"
589 )));
590 }
591
592 let rc = unsafe { (api.ze_kernel_set_group_size)(kernel, WORKGROUP_SIZE, 1, 1) };
594 if rc != 0 {
595 unsafe {
596 (api.ze_kernel_destroy)(kernel);
597 (api.ze_module_destroy)(module);
598 }
599 return Err(BackendError::DeviceError(format!(
600 "zeKernelSetGroupSize failed: 0x{rc:08x}"
601 )));
602 }
603
604 for (idx, arg) in args.iter().enumerate() {
606 let rc = match arg {
607 KernelArg::Buffer(handle) => {
608 let dev_ptr = memory.device_ptr(*handle).map_err(|e| {
609 unsafe {
610 (api.ze_kernel_destroy)(kernel);
611 (api.ze_module_destroy)(module);
612 }
613 BackendError::from(e)
614 })?;
615 unsafe {
616 (api.ze_kernel_set_argument_value)(
617 kernel,
618 idx as u32,
619 std::mem::size_of::<*mut c_void>(),
620 &dev_ptr as *const *mut c_void as *const c_void,
621 )
622 }
623 }
624 KernelArg::U32(val) => unsafe {
625 (api.ze_kernel_set_argument_value)(
626 kernel,
627 idx as u32,
628 std::mem::size_of::<u32>(),
629 val as *const u32 as *const c_void,
630 )
631 },
632 KernelArg::F32(val) => unsafe {
633 (api.ze_kernel_set_argument_value)(
634 kernel,
635 idx as u32,
636 std::mem::size_of::<f32>(),
637 val as *const f32 as *const c_void,
638 )
639 },
640 };
641 if rc != 0 {
642 unsafe {
643 (api.ze_kernel_destroy)(kernel);
644 (api.ze_module_destroy)(module);
645 }
646 return Err(BackendError::DeviceError(format!(
647 "zeKernelSetArgumentValue(arg={idx}) failed: 0x{rc:08x}"
648 )));
649 }
650 }
651
652 let list_desc = ZeCommandListDesc {
654 stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
655 p_next: std::ptr::null(),
656 command_queue_group_ordinal: 0,
657 flags: 0,
658 };
659 let mut list = std::ptr::null_mut();
660 let rc =
661 unsafe { (api.ze_command_list_create)(context, dev_handle, &list_desc, &mut list) };
662 if rc != 0 {
663 unsafe {
664 (api.ze_kernel_destroy)(kernel);
665 (api.ze_module_destroy)(module);
666 }
667 return Err(BackendError::DeviceError(format!(
668 "zeCommandListCreate failed: 0x{rc:08x}"
669 )));
670 }
671
672 let group_count = ZeGroupCount {
674 group_count_x: workgroups,
675 group_count_y: 1,
676 group_count_z: 1,
677 };
678 let rc = unsafe {
679 (api.ze_command_list_append_launch_kernel)(
680 list,
681 kernel,
682 &group_count,
683 0,
684 0,
685 std::ptr::null(),
686 )
687 };
688 if rc != 0 {
689 unsafe {
690 (api.ze_command_list_destroy)(list);
691 (api.ze_kernel_destroy)(kernel);
692 (api.ze_module_destroy)(module);
693 }
694 return Err(BackendError::DeviceError(format!(
695 "zeCommandListAppendLaunchKernel failed: 0x{rc:08x}"
696 )));
697 }
698
699 let rc = unsafe { (api.ze_command_list_close)(list) };
701 if rc != 0 {
702 unsafe {
703 (api.ze_command_list_destroy)(list);
704 (api.ze_kernel_destroy)(kernel);
705 (api.ze_module_destroy)(module);
706 }
707 return Err(BackendError::DeviceError(format!(
708 "zeCommandListClose failed: 0x{rc:08x}"
709 )));
710 }
711
712 let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
713 if rc != 0 {
714 unsafe {
715 (api.ze_command_list_destroy)(list);
716 (api.ze_kernel_destroy)(kernel);
717 (api.ze_module_destroy)(module);
718 }
719 return Err(BackendError::DeviceError(format!(
720 "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
721 )));
722 }
723
724 let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
725 if rc != 0 {
726 unsafe {
727 (api.ze_command_list_destroy)(list);
728 (api.ze_kernel_destroy)(kernel);
729 (api.ze_module_destroy)(module);
730 }
731 return Err(BackendError::DeviceError(format!(
732 "zeCommandQueueSynchronize failed: 0x{rc:08x}"
733 )));
734 }
735
736 unsafe {
738 (api.ze_command_list_destroy)(list);
739 (api.ze_kernel_destroy)(kernel);
740 (api.ze_module_destroy)(module);
741 }
742
743 Ok(())
744 }
745
746 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
747 {
748 let _ = (spv_words, args, workgroups);
749 Err(BackendError::DeviceError(
750 "Level Zero requires Linux or Windows".into(),
751 ))
752 }
753 }
754
755 fn dispatch_unary(
756 &self,
757 op: UnaryOp,
758 input_ptr: u64,
759 output_ptr: u64,
760 n: usize,
761 ) -> BackendResult<()> {
762 let spv = crate::spirv::unary_compute_shader(op);
763 let args = [
764 KernelArg::Buffer(input_ptr),
765 KernelArg::Buffer(output_ptr),
766 KernelArg::U32(n as u32),
767 ];
768 self.run_kernel(&spv, &args, (n as u32).div_ceil(WORKGROUP_SIZE))
769 }
770
771 fn dispatch_binary(
772 &self,
773 op: BinaryOp,
774 a_ptr: u64,
775 b_ptr: u64,
776 output_ptr: u64,
777 n: usize,
778 ) -> BackendResult<()> {
779 let spv = crate::spirv::binary_compute_shader(op);
780 let args = [
781 KernelArg::Buffer(a_ptr),
782 KernelArg::Buffer(b_ptr),
783 KernelArg::Buffer(output_ptr),
784 KernelArg::U32(n as u32),
785 ];
786 self.run_kernel(&spv, &args, (n as u32).div_ceil(WORKGROUP_SIZE))
787 }
788
789 fn dispatch_reduce(
790 &self,
791 op: ReduceOp,
792 input_ptr: u64,
793 output_ptr: u64,
794 shape: &[usize],
795 axis: usize,
796 ) -> BackendResult<()> {
797 let outer_size: usize = shape[..axis].iter().product::<usize>().max(1);
798 let reduce_size = shape[axis];
799 let inner_size: usize = shape[axis + 1..].iter().product::<usize>().max(1);
800
801 let spv = crate::spirv::reduce_compute_shader(op);
802 let total_output = (outer_size * inner_size) as u32;
803 let args = [
804 KernelArg::Buffer(input_ptr),
805 KernelArg::Buffer(output_ptr),
806 KernelArg::U32(outer_size as u32),
807 KernelArg::U32(reduce_size as u32),
808 KernelArg::U32(inner_size as u32),
809 ];
810 self.run_kernel(&spv, &args, total_output.div_ceil(WORKGROUP_SIZE))
811 }
812
813 fn run_kernel_3d(
818 &self,
819 spv_words: &[u32],
820 args: &[KernelArg],
821 workgroups_x: u32,
822 workgroups_y: u32,
823 workgroups_z: u32,
824 ) -> BackendResult<()> {
825 #[cfg(any(target_os = "linux", target_os = "windows"))]
826 {
827 use std::ffi::c_void;
828
829 use crate::device::{
830 ZE_MODULE_FORMAT_IL_SPIRV, ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
831 ZE_STRUCTURE_TYPE_KERNEL_DESC, ZE_STRUCTURE_TYPE_MODULE_DESC, ZeCommandListDesc,
832 ZeGroupCount, ZeKernelDesc, ZeKernelHandle, ZeModuleDesc, ZeModuleHandle,
833 };
834
835 let device = self.device.as_ref().ok_or(BackendError::NotInitialized)?;
836 let memory = self.memory()?;
837 let api = &device.api;
838 let context = device.context;
839 let dev_handle = device.device;
840 let queue = device.queue;
841
842 let spv_bytes: Vec<u8> = spv_words.iter().flat_map(|w| w.to_ne_bytes()).collect();
844
845 let module_desc = ZeModuleDesc {
847 stype: ZE_STRUCTURE_TYPE_MODULE_DESC,
848 p_next: std::ptr::null(),
849 format: ZE_MODULE_FORMAT_IL_SPIRV,
850 input_size: spv_bytes.len(),
851 p_input_module: spv_bytes.as_ptr(),
852 p_build_flags: std::ptr::null(),
853 p_constants: std::ptr::null(),
854 };
855 let mut module: ZeModuleHandle = std::ptr::null_mut();
856 let rc = unsafe {
857 (api.ze_module_create)(
858 context,
859 dev_handle,
860 &module_desc,
861 &mut module as *mut ZeModuleHandle,
862 std::ptr::null_mut(),
863 )
864 };
865 if rc != 0 {
866 return Err(BackendError::DeviceError(format!(
867 "zeModuleCreate failed: 0x{rc:08x}"
868 )));
869 }
870
871 let kernel_name = b"main\0";
873 let kernel_desc = ZeKernelDesc {
874 stype: ZE_STRUCTURE_TYPE_KERNEL_DESC,
875 p_next: std::ptr::null(),
876 flags: 0,
877 p_kernel_name: kernel_name.as_ptr(),
878 };
879 let mut kernel: ZeKernelHandle = std::ptr::null_mut();
880 let rc = unsafe {
881 (api.ze_kernel_create)(module, &kernel_desc, &mut kernel as *mut ZeKernelHandle)
882 };
883 if rc != 0 {
884 unsafe { (api.ze_module_destroy)(module) };
885 return Err(BackendError::DeviceError(format!(
886 "zeKernelCreate failed: 0x{rc:08x}"
887 )));
888 }
889
890 let rc = unsafe { (api.ze_kernel_set_group_size)(kernel, WORKGROUP_SIZE, 1, 1) };
892 if rc != 0 {
893 unsafe {
894 (api.ze_kernel_destroy)(kernel);
895 (api.ze_module_destroy)(module);
896 }
897 return Err(BackendError::DeviceError(format!(
898 "zeKernelSetGroupSize failed: 0x{rc:08x}"
899 )));
900 }
901
902 for (idx, arg) in args.iter().enumerate() {
904 let rc = match arg {
905 KernelArg::Buffer(handle) => {
906 let dev_ptr = memory.device_ptr(*handle).map_err(|e| {
907 unsafe {
908 (api.ze_kernel_destroy)(kernel);
909 (api.ze_module_destroy)(module);
910 }
911 BackendError::from(e)
912 })?;
913 unsafe {
914 (api.ze_kernel_set_argument_value)(
915 kernel,
916 idx as u32,
917 std::mem::size_of::<*mut c_void>(),
918 &dev_ptr as *const *mut c_void as *const c_void,
919 )
920 }
921 }
922 KernelArg::U32(val) => unsafe {
923 (api.ze_kernel_set_argument_value)(
924 kernel,
925 idx as u32,
926 std::mem::size_of::<u32>(),
927 val as *const u32 as *const c_void,
928 )
929 },
930 KernelArg::F32(val) => unsafe {
931 (api.ze_kernel_set_argument_value)(
932 kernel,
933 idx as u32,
934 std::mem::size_of::<f32>(),
935 val as *const f32 as *const c_void,
936 )
937 },
938 };
939 if rc != 0 {
940 unsafe {
941 (api.ze_kernel_destroy)(kernel);
942 (api.ze_module_destroy)(module);
943 }
944 return Err(BackendError::DeviceError(format!(
945 "zeKernelSetArgumentValue(arg={idx}) failed: 0x{rc:08x}"
946 )));
947 }
948 }
949
950 let list_desc = ZeCommandListDesc {
952 stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
953 p_next: std::ptr::null(),
954 command_queue_group_ordinal: 0,
955 flags: 0,
956 };
957 let mut list = std::ptr::null_mut();
958 let rc =
959 unsafe { (api.ze_command_list_create)(context, dev_handle, &list_desc, &mut list) };
960 if rc != 0 {
961 unsafe {
962 (api.ze_kernel_destroy)(kernel);
963 (api.ze_module_destroy)(module);
964 }
965 return Err(BackendError::DeviceError(format!(
966 "zeCommandListCreate failed: 0x{rc:08x}"
967 )));
968 }
969
970 let group_count = ZeGroupCount {
972 group_count_x: workgroups_x,
973 group_count_y: workgroups_y,
974 group_count_z: workgroups_z,
975 };
976 let rc = unsafe {
977 (api.ze_command_list_append_launch_kernel)(
978 list,
979 kernel,
980 &group_count,
981 0,
982 0,
983 std::ptr::null(),
984 )
985 };
986 if rc != 0 {
987 unsafe {
988 (api.ze_command_list_destroy)(list);
989 (api.ze_kernel_destroy)(kernel);
990 (api.ze_module_destroy)(module);
991 }
992 return Err(BackendError::DeviceError(format!(
993 "zeCommandListAppendLaunchKernel failed: 0x{rc:08x}"
994 )));
995 }
996
997 let rc = unsafe { (api.ze_command_list_close)(list) };
999 if rc != 0 {
1000 unsafe {
1001 (api.ze_command_list_destroy)(list);
1002 (api.ze_kernel_destroy)(kernel);
1003 (api.ze_module_destroy)(module);
1004 }
1005 return Err(BackendError::DeviceError(format!(
1006 "zeCommandListClose failed: 0x{rc:08x}"
1007 )));
1008 }
1009
1010 let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
1011 if rc != 0 {
1012 unsafe {
1013 (api.ze_command_list_destroy)(list);
1014 (api.ze_kernel_destroy)(kernel);
1015 (api.ze_module_destroy)(module);
1016 }
1017 return Err(BackendError::DeviceError(format!(
1018 "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
1019 )));
1020 }
1021
1022 let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
1023 if rc != 0 {
1024 unsafe {
1025 (api.ze_command_list_destroy)(list);
1026 (api.ze_kernel_destroy)(kernel);
1027 (api.ze_module_destroy)(module);
1028 }
1029 return Err(BackendError::DeviceError(format!(
1030 "zeCommandQueueSynchronize failed: 0x{rc:08x}"
1031 )));
1032 }
1033
1034 unsafe {
1036 (api.ze_command_list_destroy)(list);
1037 (api.ze_kernel_destroy)(kernel);
1038 (api.ze_module_destroy)(module);
1039 }
1040
1041 Ok(())
1042 }
1043
1044 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
1045 {
1046 let _ = (spv_words, args, workgroups_x, workgroups_y, workgroups_z);
1047 Err(BackendError::DeviceError(
1048 "Level Zero requires Linux or Windows".into(),
1049 ))
1050 }
1051 }
1052
1053 #[allow(clippy::too_many_arguments)]
1054 fn dispatch_batched_gemm(
1055 &self,
1056 m: usize,
1057 n: usize,
1058 k: usize,
1059 alpha: f32,
1060 a_ptr: u64,
1061 b_ptr: u64,
1062 beta: f32,
1063 c_ptr: u64,
1064 batch_count: usize,
1065 stride_a: usize,
1066 stride_b: usize,
1067 stride_c: usize,
1068 ) -> BackendResult<()> {
1069 let spv = crate::spirv::batched_gemm_compute_shader();
1070 let total_per_batch = (m * n) as u32;
1071 let workgroups_x = total_per_batch.div_ceil(WORKGROUP_SIZE);
1072 let args = [
1073 KernelArg::Buffer(a_ptr),
1074 KernelArg::Buffer(b_ptr),
1075 KernelArg::Buffer(c_ptr),
1076 KernelArg::U32(m as u32),
1077 KernelArg::U32(n as u32),
1078 KernelArg::U32(k as u32),
1079 KernelArg::F32(alpha),
1080 KernelArg::F32(beta),
1081 KernelArg::U32(batch_count as u32),
1082 KernelArg::U32(stride_a as u32),
1083 KernelArg::U32(stride_b as u32),
1084 KernelArg::U32(stride_c as u32),
1085 ];
1086 self.run_kernel_3d(&spv, &args, workgroups_x, 1, batch_count as u32)
1087 }
1088
1089 #[allow(clippy::too_many_arguments)]
1090 fn dispatch_gemm(
1091 &self,
1092 m: usize,
1093 n: usize,
1094 k: usize,
1095 alpha: f32,
1096 a_ptr: u64,
1097 b_ptr: u64,
1098 beta: f32,
1099 c_ptr: u64,
1100 ) -> BackendResult<()> {
1101 let spv = crate::spirv::gemm_compute_shader();
1102 let total = (m * n) as u32;
1103 let args = [
1104 KernelArg::Buffer(a_ptr),
1105 KernelArg::Buffer(b_ptr),
1106 KernelArg::Buffer(c_ptr),
1107 KernelArg::U32(m as u32),
1108 KernelArg::U32(n as u32),
1109 KernelArg::U32(k as u32),
1110 KernelArg::F32(alpha),
1111 KernelArg::F32(beta),
1112 ];
1113 self.run_kernel(&spv, &args, total.div_ceil(WORKGROUP_SIZE))
1114 }
1115}
1116
1117#[cfg(test)]
1120mod tests {
1121 use super::*;
1122 use oxicuda_backend::{BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp};
1123
1124 #[test]
1127 fn level_zero_backend_new_uninitialized() {
1128 let b = LevelZeroBackend::new();
1129 assert!(!b.is_initialized());
1130 }
1131
1132 #[test]
1133 fn level_zero_backend_name() {
1134 let b = LevelZeroBackend::new();
1135 assert_eq!(b.name(), "level-zero");
1136 }
1137
1138 #[test]
1139 fn level_zero_backend_default() {
1140 let b = LevelZeroBackend::default();
1141 assert!(!b.is_initialized());
1142 assert_eq!(b.name(), "level-zero");
1143 }
1144
1145 #[test]
1146 fn backend_debug_impl() {
1147 let b = LevelZeroBackend::new();
1148 let s = format!("{b:?}");
1149 assert!(s.contains("LevelZeroBackend"));
1150 }
1151
1152 #[test]
1155 fn backend_object_safe() {
1156 let b: Box<dyn ComputeBackend> = Box::new(LevelZeroBackend::new());
1157 assert_eq!(b.name(), "level-zero");
1158 }
1159
1160 #[test]
1163 fn backend_not_initialized_gemm() {
1164 let b = LevelZeroBackend::new();
1165 let result = b.gemm(
1166 BackendTranspose::NoTrans,
1167 BackendTranspose::NoTrans,
1168 4,
1169 4,
1170 4,
1171 1.0,
1172 0,
1173 4,
1174 0,
1175 4,
1176 0.0,
1177 0,
1178 4,
1179 );
1180 assert_eq!(result, Err(BackendError::NotInitialized));
1181 }
1182
1183 #[test]
1184 fn backend_not_initialized_batched_gemm() {
1185 let b = LevelZeroBackend::new();
1186 let result = b.batched_gemm(
1187 BackendTranspose::NoTrans,
1188 BackendTranspose::NoTrans,
1189 4,
1190 4,
1191 4,
1192 1.0,
1193 0,
1194 4,
1195 16,
1196 0,
1197 4,
1198 16,
1199 0.0,
1200 0,
1201 4,
1202 16,
1203 2,
1204 );
1205 assert_eq!(result, Err(BackendError::NotInitialized));
1206 }
1207
1208 #[test]
1209 fn backend_not_initialized_alloc() {
1210 let b = LevelZeroBackend::new();
1211 assert_eq!(b.alloc(1024), Err(BackendError::NotInitialized));
1212 }
1213
1214 #[test]
1215 fn backend_not_initialized_synchronize() {
1216 let b = LevelZeroBackend::new();
1217 assert_eq!(b.synchronize(), Err(BackendError::NotInitialized));
1218 }
1219
1220 #[test]
1221 fn backend_not_initialized_free() {
1222 let b = LevelZeroBackend::new();
1223 assert_eq!(b.free(1), Err(BackendError::NotInitialized));
1224 }
1225
1226 #[test]
1227 fn backend_not_initialized_copy_htod() {
1228 let b = LevelZeroBackend::new();
1229 assert_eq!(b.copy_htod(1, b"hello"), Err(BackendError::NotInitialized));
1230 }
1231
1232 #[test]
1233 fn backend_not_initialized_copy_dtoh() {
1234 let b = LevelZeroBackend::new();
1235 let mut buf = [0u8; 4];
1236 assert_eq!(b.copy_dtoh(&mut buf, 1), Err(BackendError::NotInitialized));
1237 }
1238
1239 fn try_init() -> Option<LevelZeroBackend> {
1242 let mut b = LevelZeroBackend::new();
1243 match b.init() {
1244 Ok(()) => Some(b),
1245 Err(_) => None,
1246 }
1247 }
1248
1249 #[test]
1252 fn init_graceful_failure() {
1253 let mut b = LevelZeroBackend::new();
1255 let _result = b.init();
1256 }
1258
1259 #[test]
1262 fn alloc_zero_bytes_error() {
1263 let Some(b) = try_init() else {
1264 return;
1265 };
1266 assert_eq!(
1267 b.alloc(0),
1268 Err(BackendError::InvalidArgument(
1269 "cannot allocate 0 bytes".into()
1270 ))
1271 );
1272 }
1273
1274 #[test]
1275 fn copy_htod_empty_noop() {
1276 let Some(b) = try_init() else {
1277 return;
1278 };
1279 assert_eq!(b.copy_htod(0, &[]), Ok(()));
1280 }
1281
1282 #[test]
1283 fn copy_dtoh_empty_noop() {
1284 let Some(b) = try_init() else {
1285 return;
1286 };
1287 assert_eq!(b.copy_dtoh(&mut [], 0), Ok(()));
1288 }
1289
1290 #[test]
1291 fn gemm_zero_dims_noop() {
1292 let Some(b) = try_init() else {
1293 return;
1294 };
1295 assert_eq!(
1296 b.gemm(
1297 BackendTranspose::NoTrans,
1298 BackendTranspose::NoTrans,
1299 0,
1300 0,
1301 0,
1302 1.0,
1303 0,
1304 1,
1305 0,
1306 1,
1307 0.0,
1308 0,
1309 1
1310 ),
1311 Ok(())
1312 );
1313 }
1314
1315 #[test]
1316 fn batched_gemm_zero_batch_noop() {
1317 let Some(b) = try_init() else {
1318 return;
1319 };
1320 assert_eq!(
1321 b.batched_gemm(
1322 BackendTranspose::NoTrans,
1323 BackendTranspose::NoTrans,
1324 4,
1325 4,
1326 4,
1327 1.0,
1328 0,
1329 4,
1330 16,
1331 0,
1332 4,
1333 16,
1334 0.0,
1335 0,
1336 4,
1337 16,
1338 0,
1339 ),
1340 Ok(())
1341 );
1342 }
1343
1344 #[test]
1345 fn batched_gemm_zero_dims_noop() {
1346 let Some(b) = try_init() else {
1347 return;
1348 };
1349 assert_eq!(
1350 b.batched_gemm(
1351 BackendTranspose::NoTrans,
1352 BackendTranspose::NoTrans,
1353 0,
1354 0,
1355 0,
1356 1.0,
1357 0,
1358 1,
1359 0,
1360 0,
1361 1,
1362 0,
1363 0.0,
1364 0,
1365 1,
1366 0,
1367 3,
1368 ),
1369 Ok(())
1370 );
1371 }
1372
1373 #[test]
1374 fn unary_zero_n_noop() {
1375 let Some(b) = try_init() else {
1376 return;
1377 };
1378 assert_eq!(b.unary(UnaryOp::Relu, 0, 0, 0), Ok(()));
1379 }
1380
1381 #[test]
1382 fn binary_zero_n_noop() {
1383 let Some(b) = try_init() else {
1384 return;
1385 };
1386 assert_eq!(b.binary(BinaryOp::Add, 0, 0, 0, 0), Ok(()));
1387 }
1388
1389 #[test]
1390 fn synchronize_after_init() {
1391 let Some(b) = try_init() else {
1392 return;
1393 };
1394 assert_eq!(b.synchronize(), Ok(()));
1395 }
1396
1397 #[test]
1400 fn reduce_empty_shape_error() {
1401 let Some(b) = try_init() else {
1402 return;
1403 };
1404 assert_eq!(
1405 b.reduce(ReduceOp::Sum, 0, 0, &[], 0),
1406 Err(BackendError::InvalidArgument(
1407 "shape must not be empty".into()
1408 ))
1409 );
1410 }
1411
1412 #[test]
1413 fn reduce_axis_out_of_bounds_error() {
1414 let Some(b) = try_init() else {
1415 return;
1416 };
1417 assert_eq!(
1418 b.reduce(ReduceOp::Sum, 0, 0, &[4, 4], 5),
1419 Err(BackendError::InvalidArgument(
1420 "axis 5 is out of bounds for shape of length 2".into()
1421 ))
1422 );
1423 }
1424
1425 #[test]
1426 fn attention_zero_seq_error() {
1427 let Some(b) = try_init() else {
1428 return;
1429 };
1430 assert_eq!(
1431 b.attention(0, 0, 0, 0, 1, 1, 0, 8, 64, 0.125, false),
1432 Err(BackendError::InvalidArgument(
1433 "seq_q, seq_kv, and head_dim must all be > 0".into()
1434 ))
1435 );
1436 }
1437
1438 #[test]
1439 fn attention_invalid_scale_error() {
1440 let Some(b) = try_init() else {
1441 return;
1442 };
1443 assert_eq!(
1444 b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, 0.0, false),
1445 Err(BackendError::InvalidArgument(
1446 "scale must be a positive finite number, got 0".into()
1447 ))
1448 );
1449 assert_eq!(
1450 b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, -1.0, false),
1451 Err(BackendError::InvalidArgument(
1452 "scale must be a positive finite number, got -1".into()
1453 ))
1454 );
1455 assert!(
1456 b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, f64::INFINITY, false)
1457 .is_err()
1458 );
1459 }
1460
1461 #[test]
1462 fn conv2d_wrong_input_rank() {
1463 let Some(b) = try_init() else {
1464 return;
1465 };
1466 assert_eq!(
1467 b.conv2d_forward(
1468 0,
1469 &[1, 3, 32],
1470 0,
1471 &[16, 3, 3, 3],
1472 0,
1473 &[1, 16, 30, 30],
1474 &[1, 1],
1475 &[0, 0]
1476 ),
1477 Err(BackendError::InvalidArgument(
1478 "input_shape must have 4 elements (NCHW)".into()
1479 ))
1480 );
1481 }
1482
1483 #[test]
1484 fn conv2d_wrong_filter_rank() {
1485 let Some(b) = try_init() else {
1486 return;
1487 };
1488 assert_eq!(
1489 b.conv2d_forward(
1490 0,
1491 &[1, 3, 32, 32],
1492 0,
1493 &[16, 3, 3],
1494 0,
1495 &[1, 16, 30, 30],
1496 &[1, 1],
1497 &[0, 0]
1498 ),
1499 Err(BackendError::InvalidArgument(
1500 "filter_shape must have 4 elements (KCFHFW)".into()
1501 ))
1502 );
1503 }
1504
1505 #[test]
1508 fn init_idempotent() {
1509 let Some(mut b) = try_init() else {
1510 return;
1511 };
1512 assert_eq!(b.init(), Ok(()));
1513 assert!(b.is_initialized());
1514 }
1515
1516 #[test]
1519 fn alloc_copy_roundtrip() {
1520 let Some(b) = try_init() else {
1521 return;
1522 };
1523 let src: Vec<u8> = (0u8..64).collect();
1524 let handle = match b.alloc(src.len()) {
1525 Ok(h) => h,
1526 Err(_) => return,
1527 };
1528 b.copy_htod(handle, &src).expect("copy_htod");
1529 let mut dst = vec![0u8; src.len()];
1530 b.copy_dtoh(&mut dst, handle).expect("copy_dtoh");
1531 assert_eq!(src, dst);
1532 b.free(handle).expect("free");
1533 }
1534
1535 #[test]
1538 fn double_init_is_noop() {
1539 let Some(mut b) = try_init() else {
1540 return;
1541 };
1542 let first = b.is_initialized();
1543 let _ = b.init();
1544 assert_eq!(first, b.is_initialized());
1545 }
1546
1547 #[test]
1550 fn alloc_and_free_basic() {
1551 let Some(b) = try_init() else {
1552 return;
1553 };
1554 match b.alloc(128) {
1555 Ok(handle) => {
1556 assert!(handle > 0);
1557 b.free(handle).expect("free should succeed");
1558 }
1559 Err(_) => {
1560 }
1562 }
1563 }
1564
1565 fn upload_f32(b: &LevelZeroBackend, data: &[f32]) -> Option<u64> {
1569 let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_ne_bytes()).collect();
1570 let handle = b.alloc(bytes.len()).ok()?;
1571 b.copy_htod(handle, &bytes).ok()?;
1572 Some(handle)
1573 }
1574
1575 fn download_f32(b: &LevelZeroBackend, handle: u64, len: usize) -> Option<Vec<f32>> {
1577 let mut bytes = vec![0u8; len * 4];
1578 b.copy_dtoh(&mut bytes, handle).ok()?;
1579 Some(
1580 bytes
1581 .chunks_exact(4)
1582 .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
1583 .collect(),
1584 )
1585 }
1586
1587 #[test]
1588 fn l0_conv2d_identity_1x1() {
1589 let Some(b) = try_init() else {
1590 return;
1591 };
1592 let input: Vec<f32> = (0..16).map(|i| i as f32).collect();
1594 let filter = vec![1.0f32];
1595 let output_len = 16;
1596
1597 let Some(in_h) = upload_f32(&b, &input) else {
1598 return;
1599 };
1600 let Some(flt_h) = upload_f32(&b, &filter) else {
1601 return;
1602 };
1603 let Some(out_h) = b.alloc(output_len * 4).ok() else {
1604 return;
1605 };
1606
1607 let result = b.conv2d_forward(
1608 in_h,
1609 &[1, 1, 4, 4],
1610 flt_h,
1611 &[1, 1, 1, 1],
1612 out_h,
1613 &[1, 1, 4, 4],
1614 &[1, 1],
1615 &[0, 0],
1616 );
1617 assert!(result.is_ok(), "conv2d_forward failed: {result:?}");
1618
1619 if let Some(out) = download_f32(&b, out_h, output_len) {
1620 for (i, &val) in out.iter().enumerate() {
1621 assert!(
1622 (val - input[i]).abs() < 1e-5,
1623 "mismatch at {i}: expected {}, got {val}",
1624 input[i]
1625 );
1626 }
1627 }
1628
1629 let _ = b.free(in_h);
1630 let _ = b.free(flt_h);
1631 let _ = b.free(out_h);
1632 }
1633
1634 #[test]
1635 fn l0_conv2d_3x3_basic() {
1636 let Some(b) = try_init() else {
1637 return;
1638 };
1639 let input: Vec<f32> = (0..16).map(|i| i as f32).collect();
1641 let filter = vec![1.0f32; 9];
1643 let output_len = 4;
1644
1645 let Some(in_h) = upload_f32(&b, &input) else {
1646 return;
1647 };
1648 let Some(flt_h) = upload_f32(&b, &filter) else {
1649 return;
1650 };
1651 let Some(out_h) = b.alloc(output_len * 4).ok() else {
1652 return;
1653 };
1654
1655 let result = b.conv2d_forward(
1656 in_h,
1657 &[1, 1, 4, 4],
1658 flt_h,
1659 &[1, 1, 3, 3],
1660 out_h,
1661 &[1, 1, 2, 2],
1662 &[1, 1],
1663 &[0, 0],
1664 );
1665 assert!(result.is_ok());
1666
1667 let expected = [45.0f32, 54.0, 81.0, 90.0];
1673 if let Some(out) = download_f32(&b, out_h, output_len) {
1674 for (i, &val) in out.iter().enumerate() {
1675 assert!(
1676 (val - expected[i]).abs() < 1e-4,
1677 "mismatch at {i}: expected {}, got {val}",
1678 expected[i]
1679 );
1680 }
1681 }
1682
1683 let _ = b.free(in_h);
1684 let _ = b.free(flt_h);
1685 let _ = b.free(out_h);
1686 }
1687
1688 #[test]
1689 fn l0_conv2d_with_padding() {
1690 let Some(b) = try_init() else {
1691 return;
1692 };
1693 let input: Vec<f32> = (1..=9).map(|i| i as f32).collect();
1695 let filter = vec![1.0f32; 9];
1696 let output_len = 9;
1697
1698 let Some(in_h) = upload_f32(&b, &input) else {
1699 return;
1700 };
1701 let Some(flt_h) = upload_f32(&b, &filter) else {
1702 return;
1703 };
1704 let Some(out_h) = b.alloc(output_len * 4).ok() else {
1705 return;
1706 };
1707
1708 let result = b.conv2d_forward(
1709 in_h,
1710 &[1, 1, 3, 3],
1711 flt_h,
1712 &[1, 1, 3, 3],
1713 out_h,
1714 &[1, 1, 3, 3],
1715 &[1, 1],
1716 &[1, 1],
1717 );
1718 assert!(result.is_ok());
1719
1720 if let Some(out) = download_f32(&b, out_h, output_len) {
1722 assert!(
1723 (out[4] - 45.0).abs() < 1e-4,
1724 "center expected 45, got {}",
1725 out[4]
1726 );
1727 assert!(
1729 (out[0] - 12.0).abs() < 1e-4,
1730 "corner expected 12, got {}",
1731 out[0]
1732 );
1733 }
1734
1735 let _ = b.free(in_h);
1736 let _ = b.free(flt_h);
1737 let _ = b.free(out_h);
1738 }
1739
1740 #[test]
1743 fn l0_attention_uniform() {
1744 let Some(b) = try_init() else {
1745 return;
1746 };
1747 let seq_q = 2;
1750 let seq_kv = 2;
1751 let head_dim = 2;
1752 let q = vec![0.0f32; seq_q * head_dim];
1753 let k = vec![0.0f32; seq_kv * head_dim];
1754 let v = vec![1.0f32, 2.0, 3.0, 4.0]; let o_len = seq_q * head_dim;
1756
1757 let Some(q_h) = upload_f32(&b, &q) else {
1758 return;
1759 };
1760 let Some(k_h) = upload_f32(&b, &k) else {
1761 return;
1762 };
1763 let Some(v_h) = upload_f32(&b, &v) else {
1764 return;
1765 };
1766 let Some(o_h) = b.alloc(o_len * 4).ok() else {
1767 return;
1768 };
1769 let zeros = vec![0u8; o_len * 4];
1771 let _ = b.copy_htod(o_h, &zeros);
1772
1773 let scale = 1.0 / (head_dim as f64).sqrt();
1774 let result = b.attention(
1775 q_h, k_h, v_h, o_h, 1, 1, seq_q, seq_kv, head_dim, scale, false,
1776 );
1777 assert!(result.is_ok(), "attention failed: {result:?}");
1778
1779 if let Some(out) = download_f32(&b, o_h, o_len) {
1782 for sq_idx in 0..seq_q {
1784 let off = sq_idx * head_dim;
1785 assert!(
1786 (out[off] - 2.0).abs() < 1e-4,
1787 "q{sq_idx}[0] expected 2.0, got {}",
1788 out[off]
1789 );
1790 assert!(
1791 (out[off + 1] - 3.0).abs() < 1e-4,
1792 "q{sq_idx}[1] expected 3.0, got {}",
1793 out[off + 1]
1794 );
1795 }
1796 }
1797
1798 let _ = b.free(q_h);
1799 let _ = b.free(k_h);
1800 let _ = b.free(v_h);
1801 let _ = b.free(o_h);
1802 }
1803
1804 #[test]
1805 fn l0_attention_causal() {
1806 let Some(b) = try_init() else {
1807 return;
1808 };
1809 let seq_q = 2;
1811 let seq_kv = 2;
1812 let head_dim = 2;
1813 let q = vec![0.0f32; seq_q * head_dim];
1814 let k = vec![0.0f32; seq_kv * head_dim];
1815 let v = vec![1.0f32, 2.0, 3.0, 4.0];
1816 let o_len = seq_q * head_dim;
1817
1818 let Some(q_h) = upload_f32(&b, &q) else {
1819 return;
1820 };
1821 let Some(k_h) = upload_f32(&b, &k) else {
1822 return;
1823 };
1824 let Some(v_h) = upload_f32(&b, &v) else {
1825 return;
1826 };
1827 let Some(o_h) = b.alloc(o_len * 4).ok() else {
1828 return;
1829 };
1830 let zeros = vec![0u8; o_len * 4];
1831 let _ = b.copy_htod(o_h, &zeros);
1832
1833 let scale = 1.0 / (head_dim as f64).sqrt();
1834 let result = b.attention(
1835 q_h, k_h, v_h, o_h, 1, 1, seq_q, seq_kv, head_dim, scale, true,
1836 );
1837 assert!(result.is_ok());
1838
1839 if let Some(out) = download_f32(&b, o_h, o_len) {
1840 assert!(
1842 (out[0] - 1.0).abs() < 1e-4,
1843 "q0[0] expected 1.0, got {}",
1844 out[0]
1845 );
1846 assert!(
1847 (out[1] - 2.0).abs() < 1e-4,
1848 "q0[1] expected 2.0, got {}",
1849 out[1]
1850 );
1851 assert!(
1853 (out[2] - 2.0).abs() < 1e-4,
1854 "q1[0] expected 2.0, got {}",
1855 out[2]
1856 );
1857 assert!(
1858 (out[3] - 3.0).abs() < 1e-4,
1859 "q1[1] expected 3.0, got {}",
1860 out[3]
1861 );
1862 }
1863
1864 let _ = b.free(q_h);
1865 let _ = b.free(k_h);
1866 let _ = b.free(v_h);
1867 let _ = b.free(o_h);
1868 }
1869
1870 #[test]
1871 fn l0_attention_dominant_key() {
1872 let Some(b) = try_init() else {
1873 return;
1874 };
1875 let seq_q = 1;
1877 let seq_kv = 3;
1878 let head_dim = 2;
1879 let q = vec![10.0f32, 0.0];
1881 let k = vec![10.0f32, 0.0, 0.0, 0.0, 0.0, 0.0];
1883 let v = vec![1.0f32, 0.0, 0.0, 1.0, 0.0, 0.0]; let o_len = seq_q * head_dim;
1885
1886 let Some(q_h) = upload_f32(&b, &q) else {
1887 return;
1888 };
1889 let Some(k_h) = upload_f32(&b, &k) else {
1890 return;
1891 };
1892 let Some(v_h) = upload_f32(&b, &v) else {
1893 return;
1894 };
1895 let Some(o_h) = b.alloc(o_len * 4).ok() else {
1896 return;
1897 };
1898 let zeros = vec![0u8; o_len * 4];
1899 let _ = b.copy_htod(o_h, &zeros);
1900
1901 let scale = 1.0;
1902 let result = b.attention(
1903 q_h, k_h, v_h, o_h, 1, 1, seq_q, seq_kv, head_dim, scale, false,
1904 );
1905 assert!(result.is_ok());
1906
1907 if let Some(out) = download_f32(&b, o_h, o_len) {
1908 assert!(out[0] > 0.99, "expected output[0] ≈ 1.0, got {}", out[0]);
1911 assert!(out[1] < 0.01, "expected output[1] ≈ 0.0, got {}", out[1]);
1912 }
1913
1914 let _ = b.free(q_h);
1915 let _ = b.free(k_h);
1916 let _ = b.free(v_h);
1917 let _ = b.free(o_h);
1918 }
1919}