use wgpu::{Adapter, Device, DeviceDescriptor, Features, Instance, Limits, Queue};
/// Bundle of the wgpu handles produced by [`GpuDevice::new`] together with a
/// snapshot of the limits recorded at creation time.
pub struct GpuDevice {
/// Logical device returned by `Adapter::request_device`.
pub device: Device,
/// Queue returned alongside `device` by `Adapter::request_device`.
pub queue: Queue,
/// Adapter the device was requested from.
pub adapter: Adapter,
/// Limits and identification strings captured during construction.
pub capabilities: GpuCapabilities,
}
/// Snapshot of selected wgpu limits plus adapter identification, filled in by
/// `GpuDevice::new`.
#[derive(Debug, Clone)]
pub struct GpuCapabilities {
/// Copy of the wgpu `max_buffer_size` limit (bytes; the init log divides it
/// by 1024 * 1024 to print megabytes).
pub max_buffer_size: u64,
/// Copy of the wgpu `max_storage_buffer_binding_size` limit.
pub max_storage_buffer_binding_size: u32,
/// Copy of the wgpu `max_compute_workgroup_size_x` limit.
pub max_compute_workgroup_size_x: u32,
/// Copy of the wgpu `max_compute_invocations_per_workgroup` limit.
pub max_compute_invocations_per_workgroup: u32,
/// Adapter name as reported by `Adapter::get_info`.
pub adapter_name: String,
/// `Debug` rendering of the adapter's backend (from `Adapter::get_info`).
pub backend: String,
}
impl GpuDevice {
    /// Creates a GPU device on the highest-performance adapter available.
    ///
    /// Adapter and device details are logged to the browser console on
    /// `wasm32` and to stdout everywhere else.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a human-readable message when no adapter is found
    /// or when the adapter refuses to create a device.
    pub async fn new() -> Result<Self, String> {
        let instance = Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            ..Default::default()
        });

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or_else(|| "No WebGPU adapter found".to_string())?;

        let info = adapter.get_info();

        #[cfg(target_arch = "wasm32")]
        {
            web_sys::console::log_1(&format!("WebGPU adapter found: {}", info.name).into());
            web_sys::console::log_1(&format!(" Backend: {:?}", info.backend).into());
        }
        #[cfg(not(target_arch = "wasm32"))]
        {
            println!("WebGPU adapter found:");
            println!(" Name: {}", info.name);
            println!(" Backend: {:?}", info.backend);
            println!(" Vendor: 0x{:X}", info.vendor);
            println!(" Device: 0x{:X}", info.device);
            println!(" Type: {:?}", info.device_type);
        }

        let (device, queue) = adapter
            .request_device(
                &DeviceDescriptor {
                    label: Some("Ligerito GPU Device"),
                    required_features: Features::empty(),
                    required_limits: Limits::default(),
                },
                None,
            )
            .await
            .map_err(|e| format!("Failed to create device: {}", e))?;

        // The device is created with `Limits::default()`, so the adapter's own
        // limits may be higher than what this device actually enforces. Record
        // the device's limits so `can_handle_buffer` and friends stay truthful.
        let limits = device.limits();
        let backend = format!("{:?}", info.backend);
        let capabilities = GpuCapabilities {
            max_buffer_size: limits.max_buffer_size,
            max_storage_buffer_binding_size: limits.max_storage_buffer_binding_size,
            max_compute_workgroup_size_x: limits.max_compute_workgroup_size_x,
            max_compute_invocations_per_workgroup: limits.max_compute_invocations_per_workgroup,
            adapter_name: info.name,
            backend,
        };

        #[cfg(target_arch = "wasm32")]
        {
            web_sys::console::log_1(
                &format!(
                    "WebGPU device initialized: {} MB max buffer",
                    capabilities.max_buffer_size / (1024 * 1024)
                )
                .into(),
            );
        }
        #[cfg(not(target_arch = "wasm32"))]
        {
            println!("WebGPU device initialized:");
            println!(
                " Max buffer size: {} MB",
                capabilities.max_buffer_size / (1024 * 1024)
            );
            println!(
                " Max workgroup size: {}",
                capabilities.max_compute_workgroup_size_x
            );
        }

        Ok(Self {
            device,
            queue,
            adapter,
            capabilities,
        })
    }

    /// Returns `true` when `size` does not exceed the recorded
    /// `max_buffer_size` limit.
    pub fn can_handle_buffer(&self, size: u64) -> bool {
        size <= self.capabilities.max_buffer_size
    }

    /// Picks the largest power-of-two workgroup size (at most 256) that fits
    /// both the recorded limits and `problem_size`.
    ///
    /// A one-dimensional workgroup of size `n` runs `n` invocations, so the
    /// candidate must respect `max_compute_invocations_per_workgroup` as well
    /// as `max_compute_workgroup_size_x`. Falls back to `1` when nothing fits
    /// (for example when `problem_size` is `0`).
    pub fn optimal_workgroup_size(&self, problem_size: u32) -> u32 {
        const CANDIDATES: [u32; 9] = [256, 128, 64, 32, 16, 8, 4, 2, 1];

        let limit = self
            .capabilities
            .max_compute_workgroup_size_x
            .min(self.capabilities.max_compute_invocations_per_workgroup);

        CANDIDATES
            .iter()
            .copied()
            .find(|&size| size <= limit && size <= problem_size)
            .unwrap_or(1)
    }
}
/// Reports whether a GPU adapter can be obtained on this system.
///
/// Builds a fresh wgpu instance over all backends and asks for a
/// high-performance, non-fallback adapter; the result is `true` when the
/// request yields one.
pub async fn is_available() -> bool {
    let descriptor = wgpu::InstanceDescriptor {
        backends: wgpu::Backends::all(),
        ..Default::default()
    };
    let probe = Instance::new(descriptor);

    let options = wgpu::RequestAdapterOptions {
        power_preference: wgpu::PowerPreference::HighPerformance,
        compatible_surface: None,
        force_fallback_adapter: false,
    };

    let found = probe.request_adapter(&options).await;
    found.is_some()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creating a device must succeed whenever an adapter is available.
    #[tokio::test]
    async fn test_device_initialization() {
        if !is_available().await {
            eprintln!("WebGPU not available, skipping test");
            return;
        }
        let device = GpuDevice::new().await;
        assert!(device.is_ok());
        if let Ok(dev) = device {
            println!("GPU: {}", dev.capabilities.adapter_name);
            println!("Backend: {}", dev.capabilities.backend);
            assert!(dev.capabilities.max_buffer_size > 0);
        }
    }

    /// Checks workgroup-size selection against fixed, known limits.
    ///
    /// The wgpu handle types cannot be fabricated safely (zero-initialising
    /// them is undefined behaviour and their destructors would run on garbage),
    /// so the test builds a real device and overrides only the recorded
    /// capabilities. It is skipped when no GPU is present.
    #[tokio::test]
    async fn test_optimal_workgroup_size() {
        if !is_available().await {
            eprintln!("WebGPU not available, skipping test");
            return;
        }
        let mut device = match GpuDevice::new().await {
            Ok(device) => device,
            Err(err) => {
                eprintln!("GPU device unavailable ({}), skipping test", err);
                return;
            }
        };

        // Pin the limits so the assertions do not depend on the host GPU.
        device.capabilities = GpuCapabilities {
            max_buffer_size: 1 << 30,
            max_storage_buffer_binding_size: 1 << 27,
            max_compute_workgroup_size_x: 256,
            max_compute_invocations_per_workgroup: 256,
            adapter_name: "Test GPU".to_string(),
            backend: "WebGPU".to_string(),
        };

        assert_eq!(device.optimal_workgroup_size(1024), 256);
        assert_eq!(device.optimal_workgroup_size(128), 128);
        assert_eq!(device.optimal_workgroup_size(10), 8);
    }
}