use crate::error::{GpuError, Result};
/// Bundles the core wgpu handles needed to drive the GPU.
///
/// Holds the instance the adapter was requested from, the selected adapter,
/// and the device/queue pair created from that adapter. Construct one with
/// [`GpuContext::new`], [`GpuContext::new_for_surface`], or
/// [`GpuContextBuilder`] for finer control.
#[derive(Debug)]
pub struct GpuContext {
    /// The wgpu instance used to request the adapter.
    pub instance: wgpu::Instance,
    /// The adapter the device was requested from.
    pub adapter: wgpu::Adapter,
    /// Logical device created from `adapter`.
    pub device: wgpu::Device,
    /// Queue returned alongside `device`.
    pub queue: wgpu::Queue,
}
impl GpuContext {
pub async fn new() -> Result<Self> {
GpuContextBuilder::new().build().await
}
pub async fn new_for_surface(surface: &wgpu::Surface<'_>) -> Result<Self> {
GpuContextBuilder::new()
.compatible_surface(surface)
.build()
.await
}
#[must_use]
#[inline]
pub fn adapter_info(&self) -> wgpu::AdapterInfo {
self.adapter.get_info()
}
#[must_use]
#[inline]
pub fn limits(&self) -> wgpu::Limits {
self.device.limits()
}
#[must_use]
#[inline]
pub fn features(&self) -> wgpu::Features {
self.device.features()
}
pub fn poll_wait(&self) {
let _ = self.device.poll(wgpu::PollType::Wait {
timeout: None,
submission_index: None,
});
}
}
/// Boxed device-lost handler held by [`GpuContextBuilder`] until the device exists.
type DeviceLostHandler = Box<dyn Fn(wgpu::DeviceLostReason, String) + Send>;

/// Configures and creates a [`GpuContext`].
///
/// Every setter consumes the builder and hands it back, so calls can be chained
/// before the final `build`.
pub struct GpuContextBuilder<'a> {
    /// Power preference used when requesting the adapter.
    power_preference: wgpu::PowerPreference,
    /// Features requested as `required_features` for the device.
    features: wgpu::Features,
    /// Limits requested as `required_limits` for the device.
    limits: wgpu::Limits,
    /// Optional surface the adapter is requested to be compatible with.
    compatible_surface: Option<&'a wgpu::Surface<'a>>,
    /// Optional handler installed on the device once it has been created.
    device_lost_callback: Option<DeviceLostHandler>,
}
impl Default for GpuContextBuilder<'_> {
fn default() -> Self {
Self::new()
}
}
impl<'a> GpuContextBuilder<'a> {
#[must_use]
pub fn new() -> Self {
Self {
power_preference: wgpu::PowerPreference::HighPerformance,
features: wgpu::Features::empty(),
limits: wgpu::Limits::default(),
compatible_surface: None,
device_lost_callback: None,
}
}
#[must_use]
pub fn power_preference(mut self, pref: wgpu::PowerPreference) -> Self {
self.power_preference = pref;
self
}
#[must_use]
pub fn features(mut self, features: wgpu::Features) -> Self {
self.features = features;
self
}
#[must_use]
pub fn limits(mut self, limits: wgpu::Limits) -> Self {
self.limits = limits;
self
}
#[must_use]
pub fn compatible_surface(mut self, surface: &'a wgpu::Surface<'a>) -> Self {
self.compatible_surface = Some(surface);
self
}
pub fn device_lost_callback(
mut self,
callback: impl Fn(wgpu::DeviceLostReason, String) + Send + 'static,
) -> Self {
self.device_lost_callback = Some(Box::new(callback));
self
}
pub async fn build(self) -> Result<GpuContext> {
let mut desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
desc.backends = wgpu::Backends::all();
let instance = wgpu::Instance::new(desc);
tracing::debug!(
?self.power_preference,
"requesting GPU adapter"
);
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: self.power_preference,
compatible_surface: self.compatible_surface,
force_fallback_adapter: false,
})
.await
.map_err(|_| {
tracing::error!("no suitable GPU adapter found");
GpuError::AdapterNotFound
})?;
let device_desc = wgpu::DeviceDescriptor {
label: Some("mabda_device"),
required_features: self.features,
required_limits: self.limits,
..Default::default()
};
let (device, queue) =
adapter
.request_device(&device_desc)
.await
.map_err(|e: wgpu::RequestDeviceError| {
tracing::error!("GPU device request failed: {e}");
GpuError::DeviceRequest(e)
})?;
if let Some(callback) = self.device_lost_callback {
device.set_device_lost_callback(callback);
}
tracing::info!(
adapter = adapter.get_info().name,
backend = ?adapter.get_info().backend,
"GPU context initialized"
);
Ok(GpuContext {
instance,
adapter,
device,
queue,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `GpuContext` is a sized, non-empty type.
    #[test]
    fn gpu_context_types() {
        assert!(std::mem::size_of::<GpuContext>() > 0);
    }

    /// `new()` produces the documented defaults for every field.
    #[test]
    fn builder_defaults() {
        let builder = GpuContextBuilder::new();
        assert_eq!(
            builder.power_preference,
            wgpu::PowerPreference::HighPerformance
        );
        assert_eq!(builder.features, wgpu::Features::empty());
        assert_eq!(builder.limits, wgpu::Limits::default());
        assert!(builder.compatible_surface.is_none());
        assert!(builder.device_lost_callback.is_none());
    }

    /// Chained setters each store their value on the builder.
    #[test]
    fn builder_chaining() {
        let builder = GpuContextBuilder::new()
            .power_preference(wgpu::PowerPreference::LowPower)
            .features(wgpu::Features::TIMESTAMP_QUERY)
            .limits(wgpu::Limits::downlevel_defaults());
        assert_eq!(builder.power_preference, wgpu::PowerPreference::LowPower);
        assert!(builder.features.contains(wgpu::Features::TIMESTAMP_QUERY));
        assert_eq!(builder.limits, wgpu::Limits::downlevel_defaults());
    }

    /// `Default` matches `new()`.
    #[test]
    fn builder_default_trait() {
        let builder = GpuContextBuilder::default();
        assert_eq!(
            builder.power_preference,
            wgpu::PowerPreference::HighPerformance
        );
    }

    // The tests below need real hardware. They intentionally pass without
    // assertions when no adapter or device is available, e.g. on headless CI.

    #[test]
    fn headless_gpu_context() {
        let result = pollster::block_on(GpuContext::new());
        if let Ok(ctx) = result {
            let info = ctx.adapter_info();
            assert!(!info.name.is_empty());
            let limits = ctx.limits();
            assert!(limits.max_texture_dimension_2d > 0);
            let _features = ctx.features();
            ctx.poll_wait();
        }
    }

    #[test]
    fn builder_with_custom_limits() {
        let result = pollster::block_on(
            GpuContextBuilder::new()
                .limits(wgpu::Limits::default())
                .build(),
        );
        if let Ok(ctx) = result {
            assert!(ctx.limits().max_texture_dimension_2d > 0);
        }
    }

    #[test]
    fn builder_with_device_lost_callback() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let result = pollster::block_on(
            GpuContextBuilder::new()
                .device_lost_callback(move |_reason, _msg| {
                    called_clone.store(true, Ordering::SeqCst);
                })
                .build(),
        );
        if let Ok(ctx) = result {
            assert!(ctx.limits().max_texture_dimension_2d > 0);
            // While `ctx` is still alive the device has not been lost or
            // dropped, so the callback must not have fired yet.
            assert!(
                !called.load(Ordering::SeqCst),
                "device-lost callback fired on a live device"
            );
        }
    }

    #[test]
    fn builder_low_power() {
        let result = pollster::block_on(
            GpuContextBuilder::new()
                .power_preference(wgpu::PowerPreference::LowPower)
                .build(),
        );
        if let Ok(ctx) = result {
            let _info = ctx.adapter_info();
        }
    }
}