Skip to main content

wave_gpu/
lib.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! WAVE SDK for Rust: write GPU kernels in Rust, run on any GPU.
5//!
6//! Thin wrapper around the `wave-runtime` crate, providing a convenient API
7//! for kernel compilation, device detection, array management, and kernel launch.
8
9#![allow(
10    clippy::missing_errors_doc,
11    clippy::missing_panics_doc,
12    clippy::module_name_repetitions,
13    clippy::must_use_candidate
14)]
15
16pub mod array {
17    //! Array types for kernel data.
18
19    pub use wave_runtime::memory::{DeviceBuffer, ElementType};
20
21    /// Create a `DeviceBuffer` from an `f32` slice.
22    pub fn from_f32(data: &[f32]) -> DeviceBuffer {
23        DeviceBuffer::from_f32(data)
24    }
25
26    /// Create a zero-filled `f32` buffer.
27    pub fn zeros_f32(count: usize) -> DeviceBuffer {
28        DeviceBuffer::zeros_f32(count)
29    }
30
31    /// Create a `DeviceBuffer` from a `u32` slice.
32    pub fn from_u32(data: &[u32]) -> DeviceBuffer {
33        DeviceBuffer::from_u32(data)
34    }
35
36    /// Create a zero-filled `u32` buffer.
37    pub fn zeros_u32(count: usize) -> DeviceBuffer {
38        DeviceBuffer::zeros_u32(count)
39    }
40}
41
42pub mod device {
43    //! GPU detection.
44
45    pub use wave_runtime::device::{detect_gpu as detect, Device, GpuVendor};
46}
47
48pub mod kernel {
49    //! Kernel compilation and launch.
50
51    use wave_runtime::backend::translate_to_vendor;
52    use wave_runtime::device::{Device, GpuVendor};
53    use wave_runtime::error::RuntimeError;
54    use wave_runtime::launcher::launch_kernel;
55    use wave_runtime::memory::DeviceBuffer;
56
57    pub use wave_compiler::Language;
58
59    /// A compiled WAVE kernel ready for launch.
60    pub struct CompiledKernel {
61        wbin: Vec<u8>,
62        vendor_code: Option<String>,
63    }
64
65    /// Compile kernel source to a `CompiledKernel`.
66    pub fn compile(source: &str, language: Language) -> Result<CompiledKernel, RuntimeError> {
67        let wbin = wave_runtime::compiler::compile_kernel(source, language)?;
68        Ok(CompiledKernel {
69            wbin,
70            vendor_code: None,
71        })
72    }
73
74    impl CompiledKernel {
75        /// Launch this kernel on the given device.
76        pub fn launch(
77            &mut self,
78            device: &Device,
79            buffers: &mut [&mut DeviceBuffer],
80            scalars: &[u32],
81            grid: [u32; 3],
82            workgroup: [u32; 3],
83        ) -> Result<(), RuntimeError> {
84            let vendor_code = if device.vendor == GpuVendor::Emulator {
85                String::new()
86            } else if let Some(code) = &self.vendor_code {
87                code.clone()
88            } else {
89                let code = translate_to_vendor(&self.wbin, device.vendor)?;
90                self.vendor_code = Some(code.clone());
91                code
92            };
93
94            launch_kernel(
95                &vendor_code,
96                &self.wbin,
97                device.vendor,
98                buffers,
99                scalars,
100                grid,
101                workgroup,
102            )
103        }
104    }
105}
106
107pub use wave_runtime::error::RuntimeError;
108pub use wave_runtime::Language;