use std::fmt;
use std::sync::Arc;
use wgpu::{Adapter, Backend, Limits};
/// Errors raised while discovering or using a WebGPU (wgpu) adapter.
#[derive(Clone, Debug)]
pub enum WgpuError {
    /// No adapter could be obtained.
    NoAdapter,
    /// Device-level failure; the payload is a human-readable message.
    DeviceError(String),
    /// Buffer-related failure; the payload is a human-readable message.
    BufferError(String),
    /// Shader-related failure; the payload is a human-readable message.
    ShaderError(String),
}

impl fmt::Display for WgpuError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NoAdapter => f.write_str("No suitable WebGPU adapter found"),
            Self::DeviceError(msg) => write!(f, "WebGPU device error: {}", msg),
            Self::BufferError(msg) => write!(f, "WebGPU buffer error: {}", msg),
            Self::ShaderError(msg) => write!(f, "WebGPU shader error: {}", msg),
        }
    }
}

/// Lets `WgpuError` be used anywhere a `std::error::Error` is expected
/// (e.g. boxed as `Box<dyn Error>`); the message comes from `Display`.
impl std::error::Error for WgpuError {}
/// Snapshot of adapter metadata, built by `query_adapter_info` and shared
/// (via `Arc`) by every `WgpuDevice` handle that refers to the adapter.
#[derive(Clone)]
pub(crate) struct AdapterInfo {
    /// Adapter name as reported by wgpu's adapter info.
    name: String,
    /// Graphics backend the adapter was exposed through.
    backend: Backend,
    /// Resource limits reported by the adapter.
    limits: Limits,
    /// Whether the adapter advertises `wgpu::Features::SUBGROUP`.
    subgroups_supported: bool,
    /// Lower subgroup-size bound; 0 when subgroups are unsupported.
    min_subgroup_size: u32,
    /// Upper subgroup-size bound; 0 when subgroups are unsupported.
    max_subgroup_size: u32,
}
/// Cheap, cloneable handle that identifies a wgpu adapter by index and may
/// carry shared adapter metadata.
///
/// `Default` yields index 0 with no metadata attached.
#[derive(Clone, Default)]
pub struct WgpuDevice {
    /// Adapter index; `query_adapter_info` uses it to select from the
    /// enumerated adapter list.
    pub(crate) index: usize,
    /// Shared adapter metadata, or `None` for handles made via `new`/`Default`.
    info: Option<Arc<AdapterInfo>>,
}
impl WgpuDevice {
pub fn new(index: usize) -> Self {
Self { index, info: None }
}
pub(crate) fn with_info(index: usize, info: Arc<AdapterInfo>) -> Self {
Self {
index,
info: Some(info),
}
}
pub fn adapter_name(&self) -> String {
self.info
.as_ref()
.map(|i| i.name.clone())
.unwrap_or_else(|| "unknown".to_string())
}
pub fn backend(&self) -> Option<Backend> {
self.info.as_ref().map(|i| i.backend)
}
pub fn limits(&self) -> Limits {
self.info
.as_ref()
.map(|i| i.limits.clone())
.unwrap_or_default()
}
pub fn subgroups_supported(&self) -> bool {
self.info.as_ref().is_some_and(|i| i.subgroups_supported)
}
pub fn subgroup_size(&self) -> (u32, u32) {
self.info
.as_ref()
.map(|i| (i.min_subgroup_size, i.max_subgroup_size))
.unwrap_or((0, 0))
}
pub fn max_workgroup_size(&self) -> u32 {
self.limits().max_compute_workgroup_size_x
}
pub fn max_storage_buffer_size(&self) -> u64 {
self.limits().max_storage_buffer_binding_size as u64
}
}
impl super::super::Device for WgpuDevice {
    /// Returns the adapter index this handle was created with.
    fn id(&self) -> usize {
        let index = self.index;
        index
    }

    /// Returns a display name of the form `wgpu:<index>`.
    fn name(&self) -> String {
        let index = self.index;
        format!("wgpu:{}", index)
    }
}
/// Debug output shows the index plus the adapter name and backend (which fall
/// back to `"unknown"` / `None` when no metadata is attached).
impl fmt::Debug for WgpuDevice {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let adapter = self.adapter_name();
        let backend = self.backend();
        let mut out = f.debug_struct("WgpuDevice");
        out.field("index", &self.index);
        out.field("adapter", &adapter);
        out.field("backend", &backend);
        out.finish()
    }
}
/// Discovers a wgpu adapter and snapshots its metadata.
///
/// Enumerates adapters across all backends and takes the one at `index`. If
/// `index` is past the end of a non-empty list, falls back to requesting a
/// high-performance adapter from the instance instead.
///
/// # Errors
/// Returns [`WgpuError::NoAdapter`] when enumeration finds no adapters, or
/// when the fallback request fails.
pub(crate) async fn query_adapter_info(
    index: usize,
) -> Result<(Adapter, Arc<AdapterInfo>), WgpuError> {
    let instance = wgpu::Instance::default();
    let mut candidates: Vec<_> = instance.enumerate_adapters(wgpu::Backends::all()).await;
    if candidates.is_empty() {
        return Err(WgpuError::NoAdapter);
    }

    let adapter = if index >= candidates.len() {
        // NOTE(review): an out-of-range index silently yields a different
        // adapter than requested; callers cannot tell from the result.
        let options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        };
        instance
            .request_adapter(&options)
            .await
            .map_err(|_| WgpuError::NoAdapter)?
    } else {
        // O(1) removal; the remaining adapters are discarded, so order is irrelevant.
        candidates.swap_remove(index)
    };

    let details = adapter.get_info();
    let limits = adapter.limits();
    let subgroups_supported = adapter.features().contains(wgpu::Features::SUBGROUP);
    // NOTE(review): these bounds are fixed placeholders, not values queried
    // from the adapter — confirm against the subgroup-size API of the wgpu
    // version in use.
    let (min_subgroup_size, max_subgroup_size): (u32, u32) = if subgroups_supported {
        (4, 64)
    } else {
        (0, 0)
    };

    let snapshot = AdapterInfo {
        name: details.name,
        backend: details.backend,
        limits,
        subgroups_supported,
        min_subgroup_size,
        max_subgroup_size,
    };
    Ok((adapter, Arc::new(snapshot)))
}
/// Synchronous wrapper around [`query_adapter_info`] that drives the future to
/// completion on the current thread with `pollster`.
///
/// # Errors
/// Propagates any [`WgpuError`] returned by [`query_adapter_info`].
pub(crate) fn query_adapter_info_blocking(
    index: usize,
) -> Result<(Adapter, Arc<AdapterInfo>), WgpuError> {
    let pending = query_adapter_info(index);
    pollster::block_on(pending)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Device;

    /// A handle built from a bare index reports that index as id and in its name.
    #[test]
    fn test_wgpu_device_creation() {
        let handle = WgpuDevice::new(0);
        assert_eq!(handle.id(), 0);
        assert_eq!(handle.name(), "wgpu:0");
    }

    /// Queries adapter 0 and prints its metadata; passes without checks when
    /// no adapter can be obtained (e.g. on machines without a GPU).
    #[test]
    fn test_wgpu_device_with_adapter() {
        let info = match query_adapter_info_blocking(0) {
            Ok((_, info)) => info,
            Err(e) => {
                println!("No GPU available, skipping test: {}", e);
                return;
            }
        };

        let handle = WgpuDevice::with_info(0, info);
        println!("Adapter: {}", handle.adapter_name());
        println!("Backend: {:?}", handle.backend());
        println!("Max workgroup size: {}", handle.max_workgroup_size());
        println!("Subgroups: {}", handle.subgroups_supported());
        assert!(!handle.adapter_name().is_empty());
    }
}