use std::sync::{Arc, Mutex, MutexGuard, OnceLock};
/// Shared wgpu handles for vyre-conform.
///
/// Built by `create_gpu_context` and shared as `Arc<GpuContext>` through `get_gpu`.
/// A new instance is built after the cached one is evicted by `clear_gpu_cache`.
pub(super) struct GpuContext {
    /// Logical device requested from the selected adapter (label `"vyre-conform"`).
    pub(super) device: wgpu::Device,
    /// Queue returned together with `device` by `request_device`.
    pub(super) queue: wgpu::Queue,
    /// Features reported by the *adapter* (`adapter.features()`), not the device.
    /// NOTE(review): the device is requested with a defaulted `DeviceDescriptor`, so it
    /// presumably has no optional features enabled — confirm callers don't treat this
    /// field as the device's enabled feature set.
    pub(super) adapter_features: wgpu::Features,
    /// Downlevel capabilities reported by the adapter; `create_gpu_context` asserts that
    /// `COMPUTE_SHADERS` is present before returning the context.
    pub(super) adapter_downlevel: wgpu::DownlevelCapabilities,
}
/// Process-wide slot for the shared GPU context. `None` means "not built yet" or
/// "evicted after a device loss".
static GPU: OnceLock<Mutex<Option<Arc<GpuContext>>>> = OnceLock::new();

/// Returns the process-wide cache slot, lazily creating it (empty) on first access.
fn gpu_cache() -> &'static Mutex<Option<Arc<GpuContext>>> {
    // `Mutex<Option<_>>::default()` is `Mutex::new(None)`.
    GPU.get_or_init(Default::default)
}
/// Locks the GPU cache slot, recovering the guard even if a previous holder panicked.
///
/// The slot is only ever assigned `None` or a completely built context, so a poisoned
/// lock carries no half-written state and is safe to keep using.
fn lock_gpu_cache() -> MutexGuard<'static, Option<Arc<GpuContext>>> {
    gpu_cache()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Device-lost callback for the cached device: logs the loss and evicts the cached context
/// so the next `get_gpu` call builds a fresh one.
///
/// The evicted context is dropped only *after* the cache lock is released. If the cache held
/// the last `Arc`, dropping it tears down the wgpu device and queue. Doing that under the
/// global lock would stall every concurrent `get_gpu` caller for the duration of the teardown.
fn clear_gpu_cache(reason: wgpu::DeviceLostReason, message: String) {
    tracing::warn!(
        "vyre-conform GPU device lost: {:?}: {}. Clearing cached device.",
        reason,
        message
    );
    // The temporary guard is dropped at the end of this statement, releasing the lock
    // before `evicted` goes out of scope.
    let evicted = lock_gpu_cache().take();
    drop(evicted);
}
/// Returns the shared GPU context, creating and caching it on first use.
///
/// Returns `None` when no adapter or device could be obtained. Nothing is cached in that
/// case, so a later call retries. After a device loss, `clear_gpu_cache` empties the cache
/// and the next call builds a fresh context.
///
/// The cache lock is held while the context is created, so concurrent first callers wait
/// and share one device instead of each creating their own.
///
/// # Panics
///
/// Panics (from `create_gpu_context`) if the selected adapter is not a GPU or lacks compute
/// shader support.
pub(super) fn get_gpu() -> Option<Arc<GpuContext>> {
    let ctx = {
        let mut cache = lock_gpu_cache();
        if let Some(ctx) = cache.as_ref() {
            return Some(Arc::clone(ctx));
        }
        let ctx = Arc::new(create_gpu_context()?);
        *cache = Some(Arc::clone(&ctx));
        ctx
    };
    // Register the device-lost callback only now that the cache lock has been released.
    // `clear_gpu_cache` takes that same non-reentrant lock, and wgpu can run the callback
    // synchronously on the registering thread (e.g. when the device is dropped or already
    // invalid). Registering while the lock is held could deadlock this thread.
    ctx.device.set_device_lost_callback(clear_gpu_cache);
    Some(ctx)
}

/// Creates a new wgpu device and queue on the high-performance adapter.
///
/// Returns `None` if no adapter is available or the adapter fails to create a device.
/// No device-lost callback is registered here; `get_gpu` registers it once the context is
/// cached and the cache lock is released.
///
/// # Panics
///
/// Panics if the adapter's device type is `Cpu` or `Other`, or if it lacks compute shader
/// support. Both checks run before any device is requested, so no device is created on an
/// adapter that is about to be rejected. It also means no device is dropped while the panic
/// unwinds through `get_gpu`, which holds the cache lock at that point.
fn create_gpu_context() -> Option<GpuContext> {
    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
        backends: wgpu::Backends::all(),
        ..Default::default()
    });
    let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
        power_preference: wgpu::PowerPreference::HighPerformance,
        ..Default::default()
    }))?;

    // Everything checked here is an adapter property, so validate before creating a device.
    let info = adapter.get_info();
    let features = adapter.features();
    let downlevel = adapter.get_downlevel_capabilities();
    assert!(
        !matches!(info.device_type, wgpu::DeviceType::Cpu | wgpu::DeviceType::Other),
        "Fix: adapter {} is not a GPU (type {:?}). vyre-conform requires a discrete or integrated GPU.",
        info.name,
        info.device_type
    );
    assert!(
        downlevel.flags.contains(wgpu::DownlevelFlags::COMPUTE_SHADERS),
        "Fix: adapter {} does not support compute shaders. vyre-conform requires baseline compute capability.",
        info.name
    );

    let (device, queue) = pollster::block_on(adapter.request_device(
        &wgpu::DeviceDescriptor {
            label: Some("vyre-conform"),
            ..Default::default()
        },
        None,
    ))
    .ok()?;
    tracing::info!(
        "vyre-conform using GPU adapter: {} ({:?})",
        info.name,
        info.device_type
    );
    Some(GpuContext {
        device,
        queue,
        adapter_features: features,
        adapter_downlevel: downlevel,
    })
}