// Crate-wide opt-outs from lints in clippy's pedantic group.
// NOTE(review): the rationale for each opt-out is not recorded here — confirm and
// document it. `module_name_repetitions` presumably covers names such as
// `kernel::CompiledKernel`; the doc-related ones may become unnecessary once every
// public fn carries `# Errors` / `# Panics` sections.
#![allow(
    clippy::missing_errors_doc,
    clippy::missing_panics_doc,
    clippy::module_name_repetitions,
    clippy::must_use_candidate
)]
15
16pub mod array {
17 pub use wave_runtime::memory::{DeviceBuffer, ElementType};
20
21 pub fn from_f32(data: &[f32]) -> DeviceBuffer {
23 DeviceBuffer::from_f32(data)
24 }
25
26 pub fn zeros_f32(count: usize) -> DeviceBuffer {
28 DeviceBuffer::zeros_f32(count)
29 }
30
31 pub fn from_u32(data: &[u32]) -> DeviceBuffer {
33 DeviceBuffer::from_u32(data)
34 }
35
36 pub fn zeros_u32(count: usize) -> DeviceBuffer {
38 DeviceBuffer::zeros_u32(count)
39 }
40}
41
42pub mod device {
43 pub use wave_runtime::device::{detect_gpu as detect, Device, GpuVendor};
46}
47
48pub mod kernel {
49 use wave_runtime::backend::translate_to_vendor;
52 use wave_runtime::device::{Device, GpuVendor};
53 use wave_runtime::error::RuntimeError;
54 use wave_runtime::launcher::launch_kernel;
55 use wave_runtime::memory::DeviceBuffer;
56
57 pub use wave_compiler::Language;
58
59 pub struct CompiledKernel {
61 wbin: Vec<u8>,
62 vendor_code: Option<String>,
63 }
64
65 pub fn compile(source: &str, language: Language) -> Result<CompiledKernel, RuntimeError> {
67 let wbin = wave_runtime::compiler::compile_kernel(source, language)?;
68 Ok(CompiledKernel {
69 wbin,
70 vendor_code: None,
71 })
72 }
73
74 impl CompiledKernel {
75 pub fn launch(
77 &mut self,
78 device: &Device,
79 buffers: &mut [&mut DeviceBuffer],
80 scalars: &[u32],
81 grid: [u32; 3],
82 workgroup: [u32; 3],
83 ) -> Result<(), RuntimeError> {
84 let vendor_code = if device.vendor == GpuVendor::Emulator {
85 String::new()
86 } else if let Some(code) = &self.vendor_code {
87 code.clone()
88 } else {
89 let code = translate_to_vendor(&self.wbin, device.vendor)?;
90 self.vendor_code = Some(code.clone());
91 code
92 };
93
94 launch_kernel(
95 &vendor_code,
96 &self.wbin,
97 device.vendor,
98 buffers,
99 scalars,
100 grid,
101 workgroup,
102 )
103 }
104 }
105}
106
107pub use wave_runtime::error::RuntimeError;
108pub use wave_runtime::Language;