/// Compute backend a buffer or operation is associated with.
///
/// [`WebGpuBackend::Cpu`] is the `Default` variant.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum WebGpuBackend {
    /// WebGPU device backend.
    WebGpu,
    /// Host (CPU) backend; selected by `WebGpuBackend::default()`.
    #[default]
    Cpu,
}
/// Tunable settings for WebGPU-style compute dispatch.
///
/// This module only stores these values; how each one is interpreted is up to
/// the code that consumes the config.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct WebGpuConfig {
    /// Whether the GPU backend should be preferred (default `true`).
    pub prefer_gpu: bool,
    /// Whether falling back to the CPU is permitted (default `true`).
    pub fallback_to_cpu: bool,
    /// Tile edge length (default 16).
    pub tile_size: usize,
    /// Buffer size cap (default 256 MiB). NOTE(review): presumably bytes, to
    /// match the byte units printed by `GpuError::BufferTooLarge` — confirm.
    pub max_buffer_size: usize,
    /// Workgroup X dimension (default 16).
    pub workgroup_size_x: u32,
    /// Workgroup Y dimension (default 16).
    pub workgroup_size_y: u32,
    /// Workgroup Z dimension (default 1).
    pub workgroup_size_z: u32,
    /// Whether validation is enabled (default `true`).
    pub enable_validation: bool,
}

impl Default for WebGpuConfig {
    /// GPU preferred with CPU fallback allowed, 16-wide tiles, a 256 MiB
    /// buffer cap, a 16 x 16 x 1 workgroup, and validation switched on.
    fn default() -> Self {
        const MIB: usize = 1 << 20;
        let (wg_x, wg_y, wg_z) = (16, 16, 1);
        Self {
            prefer_gpu: true,
            fallback_to_cpu: true,
            tile_size: 16,
            max_buffer_size: 256 * MIB,
            workgroup_size_x: wg_x,
            workgroup_size_y: wg_y,
            workgroup_size_z: wg_z,
            enable_validation: true,
        }
    }
}
/// Declared usage of a GPU buffer.
///
/// [`GpuBufferUsage::Storage`] is the `Default` variant.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GpuBufferUsage {
    /// Storage usage; selected by `GpuBufferUsage::default()`.
    #[default]
    Storage,
    /// Uniform usage.
    Uniform,
    /// Staging usage.
    Staging,
    /// Vertex usage.
    Vertex,
}
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct GpuBufferDescriptor {
pub size: usize,
pub usage: GpuBufferUsage,
pub label: Option<String>,
}
impl Default for GpuBufferDescriptor {
fn default() -> Self {
Self {
size: 0,
usage: GpuBufferUsage::Storage,
label: None,
}
}
}
/// Description of a compute pipeline: shader source, entry point name, and
/// bind-group count.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ComputePipelineDescriptor {
    /// Shader source text (empty by default).
    pub shader_source: String,
    /// Entry-point function name (`"main"` by default).
    pub entry_point: String,
    /// Number of bind groups (1 by default).
    pub bind_group_count: u32,
}

impl Default for ComputePipelineDescriptor {
    /// Empty shader source, entry point `"main"`, and a single bind group.
    /// This is written by hand because `"main"` is not `String::default()`.
    fn default() -> Self {
        let entry_point = String::from("main");
        Self {
            entry_point,
            shader_source: String::default(),
            bind_group_count: 1,
        }
    }
}
/// Errors reported by the WebGPU compute layer.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum GpuError {
    /// No WebGPU device is available.
    DeviceNotAvailable,
    /// A requested buffer exceeds the limit; `Display` reports both numbers
    /// as bytes.
    BufferTooLarge { requested: usize, limit: usize },
    /// Shader compilation failed; carries the compiler message.
    ShaderCompile(String),
    /// Execution failed; carries a description of the failure.
    Execution(String),
}

impl std::fmt::Display for GpuError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            GpuError::DeviceNotAvailable => f.write_str("WebGPU device not available"),
            GpuError::BufferTooLarge { requested, limit } => {
                write!(f, "buffer size {requested} bytes exceeds limit {limit} bytes")
            }
            GpuError::ShaderCompile(msg) => write!(f, "shader compilation error: {msg}"),
            GpuError::Execution(msg) => write!(f, "GPU execution error: {msg}"),
        }
    }
}

/// `GpuError` has no underlying cause, so the default `source()` (`None`) is used.
impl std::error::Error for GpuError {}

/// Convenience result type for fallible WebGPU operations.
pub type WebGpuResult<T> = Result<T, GpuError>;
/// Host-side `f32` buffer tagged with a backend and a declared usage.
///
/// Every constructor sets `size` to `data.len()`, so it is an element count,
/// not a byte count. Both fields are public, so code that resizes `data`
/// directly is responsible for keeping `size` in sync.
#[derive(Debug, Clone)]
pub struct GpuBuffer {
    /// Buffer contents.
    pub data: Vec<f32>,
    /// Number of `f32` elements in `data` when the buffer was constructed.
    pub size: usize,
    /// Backend this buffer is associated with.
    pub backend: WebGpuBackend,
    /// Declared usage of this buffer.
    pub usage: GpuBufferUsage,
}

impl GpuBuffer {
    /// Wraps `data` for the given `backend`, with `Storage` usage.
    pub fn new(data: Vec<f32>, backend: WebGpuBackend) -> Self {
        Self {
            backend,
            ..Self::with_data(data, GpuBufferUsage::Storage)
        }
    }

    /// Wraps `data` with the given `usage`, tagged with the CPU backend.
    pub fn with_data(data: Vec<f32>, usage: GpuBufferUsage) -> Self {
        // `size` is listed first so `len()` is read before `data` is moved.
        Self {
            size: data.len(),
            data,
            backend: WebGpuBackend::Cpu,
            usage,
        }
    }

    /// Creates a buffer of `size` zeros with the given `usage`, tagged with
    /// the CPU backend.
    pub fn zeros(size: usize, usage: GpuBufferUsage) -> Self {
        Self::with_data(vec![0.0; size], usage)
    }

    /// Borrows the contents as a slice.
    pub fn as_slice(&self) -> &[f32] {
        self.data.as_slice()
    }

    /// Mutably borrows the contents. The slice cannot change the length, so
    /// `size` stays valid.
    pub fn as_mut_slice(&mut self) -> &mut [f32] {
        self.data.as_mut_slice()
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the public configuration, descriptor, error, and buffer
    // types. Coverage includes `GpuBuffer::new`, the slice accessors, the
    // `max_buffer_size` default, empty buffers, and the `Error`/result-alias
    // plumbing.
    use super::*;

    #[test]
    fn test_webgpu_config_default() {
        let cfg = WebGpuConfig::default();
        assert!(cfg.prefer_gpu);
        assert!(cfg.fallback_to_cpu);
        assert_eq!(cfg.tile_size, 16);
        assert_eq!(cfg.max_buffer_size, 256 * 1024 * 1024);
        assert_eq!(cfg.workgroup_size_x, 16);
        assert_eq!(cfg.workgroup_size_y, 16);
        assert_eq!(cfg.workgroup_size_z, 1);
        assert!(cfg.enable_validation);
    }

    #[test]
    fn test_gpu_buffer_zeros() {
        let buf = GpuBuffer::zeros(8, GpuBufferUsage::Storage);
        assert_eq!(buf.size, 8);
        assert_eq!(buf.data.len(), 8);
        assert!(buf.data.iter().all(|&v| v == 0.0));
        assert_eq!(buf.backend, WebGpuBackend::Cpu);
        assert_eq!(buf.usage, GpuBufferUsage::Storage);
    }

    #[test]
    fn test_gpu_buffer_zeros_empty() {
        let buf = GpuBuffer::zeros(0, GpuBufferUsage::Uniform);
        assert_eq!(buf.size, 0);
        assert!(buf.as_slice().is_empty());
        assert_eq!(buf.usage, GpuBufferUsage::Uniform);
    }

    #[test]
    fn test_webgpu_backend_default() {
        let b = WebGpuBackend::default();
        assert_eq!(b, WebGpuBackend::Cpu);
    }

    #[test]
    fn test_gpu_buffer_usage_default() {
        let u = GpuBufferUsage::default();
        assert_eq!(u, GpuBufferUsage::Storage);
    }

    #[test]
    fn test_gpu_buffer_descriptor_default() {
        let d = GpuBufferDescriptor::default();
        assert_eq!(d.size, 0);
        assert_eq!(d.usage, GpuBufferUsage::Storage);
        assert!(d.label.is_none());
    }

    #[test]
    fn test_compute_pipeline_descriptor_default() {
        let d = ComputePipelineDescriptor::default();
        assert_eq!(d.entry_point, "main");
        assert_eq!(d.bind_group_count, 1);
        assert!(d.shader_source.is_empty());
    }

    #[test]
    fn test_gpu_error_display_device_not_available() {
        let e = GpuError::DeviceNotAvailable;
        let s = e.to_string();
        assert!(s.contains("not available"));
    }

    #[test]
    fn test_gpu_error_display_buffer_too_large() {
        let e = GpuError::BufferTooLarge {
            requested: 1024,
            limit: 512,
        };
        let s = e.to_string();
        assert!(s.contains("1024"));
        assert!(s.contains("512"));
    }

    #[test]
    fn test_gpu_error_display_shader_compile() {
        let e = GpuError::ShaderCompile("unexpected token".to_string());
        let s = e.to_string();
        assert!(s.contains("unexpected token"));
    }

    #[test]
    fn test_gpu_error_display_execution() {
        let e = GpuError::Execution("dispatch failed".to_string());
        let s = e.to_string();
        assert!(s.contains("dispatch failed"));
    }

    #[test]
    fn test_gpu_error_is_std_error() {
        let err: Box<dyn std::error::Error> = Box::new(GpuError::DeviceNotAvailable);
        assert_eq!(err.to_string(), "WebGPU device not available");
    }

    #[test]
    fn test_webgpu_result_alias() {
        let ok: WebGpuResult<u32> = Ok(7);
        let err: WebGpuResult<u32> = Err(GpuError::Execution("boom".to_string()));
        assert_eq!(ok, Ok(7));
        assert_eq!(err.unwrap_err(), GpuError::Execution("boom".to_string()));
    }

    #[test]
    fn test_gpu_buffer_with_data() {
        let data = vec![1.0_f32, 2.0, 3.0];
        let buf = GpuBuffer::with_data(data.clone(), GpuBufferUsage::Staging);
        assert_eq!(buf.usage, GpuBufferUsage::Staging);
        assert_eq!(buf.backend, WebGpuBackend::Cpu);
        assert_eq!(buf.data, data);
        assert_eq!(buf.size, 3);
    }

    #[test]
    fn test_gpu_buffer_new() {
        let buf = GpuBuffer::new(vec![1.0_f32, 2.0], WebGpuBackend::WebGpu);
        assert_eq!(buf.size, 2);
        assert_eq!(buf.backend, WebGpuBackend::WebGpu);
        assert_eq!(buf.usage, GpuBufferUsage::Storage);
        assert_eq!(buf.as_slice(), &[1.0_f32, 2.0][..]);
    }

    #[test]
    fn test_gpu_buffer_as_mut_slice_writes_through() {
        let mut buf = GpuBuffer::zeros(3, GpuBufferUsage::Storage);
        buf.as_mut_slice()[1] = 5.0;
        assert_eq!(buf.as_slice(), &[0.0_f32, 5.0, 0.0][..]);
        assert_eq!(buf.data, vec![0.0_f32, 5.0, 0.0]);
        assert_eq!(buf.size, 3);
    }
}