use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};
use vyre::error::{Error, Result};
/// Address of the cached `wgpu::Device`, used by `is_cached_device` for an
/// identity check by address. `0` means no address has been recorded yet.
pub(crate) static CACHED_DEVICE_PTR: AtomicUsize = AtomicUsize::new(0);
/// Device/queue pair held in the process-wide cache owned by `cached_gpu`.
pub(crate) struct CachedGpu {
/// The cached `(device, queue)` produced by `init_device`.
pub(crate) pair: (wgpu::Device, wgpu::Queue),
}
#[inline]
pub fn cached_device() -> Result<&'static (wgpu::Device, wgpu::Queue)> {
let cached = cached_gpu()?;
CACHED_DEVICE_PTR.store(
std::ptr::from_ref::<wgpu::Device>(&cached.pair.0).addr(),
Ordering::Release,
);
Ok(&cached.pair)
}
#[inline]
pub(crate) fn cached_gpu() -> Result<&'static CachedGpu> {
static GPU: OnceLock<Arc<CachedGpu>> = OnceLock::new();
static INIT_LOCK: Mutex<()> = Mutex::new(());
if let Some(gpu) = GPU.get() {
return Ok(gpu.as_ref());
}
let _guard = INIT_LOCK.lock().map_err(|source| Error::Gpu {
message: format!(
"cached GPU init mutex poisoned: {source}. Fix: restart the process or avoid panicking while initializing the GPU cache."
),
})?;
if let Some(gpu) = GPU.get() {
return Ok(gpu.as_ref());
}
let pair = init_device()?;
let _ = GPU.set(Arc::new(CachedGpu { pair }));
let gpu = GPU.get().ok_or_else(|| Error::Gpu {
message: "cached GPU initialization failed after successful device creation. Fix: report this vyre runtime bug with the active platform and wgpu backend.".to_string(),
})?;
Ok(gpu.as_ref())
}
/// Reports whether `device` is the process-wide cached device.
///
/// Identity is decided by the address of the `wgpu::Device` value, compared
/// against `CACHED_DEVICE_PTR`. Any other `Device` value compares unequal.
/// The result is `false` until the cached device's address has been recorded.
#[inline]
pub(crate) fn is_cached_device(device: &wgpu::Device) -> bool {
    let recorded_addr = CACHED_DEVICE_PTR.load(Ordering::Acquire);
    recorded_addr == std::ptr::from_ref(device).addr()
}
/// Creates a fresh `(wgpu::Device, wgpu::Queue)`, blocking the calling thread
/// until `acquire_gpu` completes.
///
/// This does not use or populate the process-wide cache; see `cached_device`.
///
/// # Errors
///
/// Propagates any error returned by `acquire_gpu`.
#[inline]
pub fn init_device() -> Result<(wgpu::Device, wgpu::Queue)> {
    let pending = acquire_gpu();
    wait_for_gpu(pending)
}
/// Asynchronously acquires a wgpu device and queue from the default adapter.
///
/// The adapter is taken from `wgpu::Instance::default()` with default request
/// options. `TIMESTAMP_QUERY` is requested only when the adapter supports it.
/// Limits are wgpu's defaults, except `max_storage_buffers_per_shader_stage`,
/// which is copied from the adapter.
///
/// NOTE(review): every other limit stays at `wgpu::Limits::default()`, so an
/// adapter reporting lower values may fail in `request_device`.
///
/// # Errors
///
/// Returns `Error::Gpu` in any of these cases:
/// - no adapter is available;
/// - the adapter's device type is `Cpu` or `Other`;
/// - `request_device` fails.
#[inline]
pub async fn acquire_gpu() -> Result<(wgpu::Device, wgpu::Queue)> {
    let instance = wgpu::Instance::default();
    let Some(adapter) = instance
        .request_adapter(&wgpu::RequestAdapterOptions::default())
        .await
    else {
        return Err(Error::Gpu {
            message: "failed to acquire adapter. Fix: install a compatible GPU driver, expose a wgpu-supported adapter, or run on a host with GPU access.".to_string(),
        });
    };

    // Software (CPU) and unclassified adapters are rejected outright.
    let info = adapter.get_info();
    if let wgpu::DeviceType::Cpu | wgpu::DeviceType::Other = info.device_type {
        return Err(Error::Gpu {
            message: format!(
                "adapter '{}' has device type {:?}, which is not a real GPU execution target. Fix: expose a discrete, integrated, or virtual GPU adapter before running vyre.",
                info.name, info.device_type
            ),
        });
    }

    // Timestamp queries are optional: request them only when available.
    let required_features = if adapter.features().contains(wgpu::Features::TIMESTAMP_QUERY) {
        wgpu::Features::TIMESTAMP_QUERY
    } else {
        wgpu::Features::empty()
    };

    let required_limits = wgpu::Limits {
        max_storage_buffers_per_shader_stage: adapter.limits().max_storage_buffers_per_shader_stage,
        ..wgpu::Limits::default()
    };

    let descriptor = wgpu::DeviceDescriptor {
        label: Some("vyre device"),
        required_features,
        required_limits,
        memory_hints: wgpu::MemoryHints::default(),
    };

    adapter
        .request_device(&descriptor, None)
        .await
        .map_err(|error| Error::Gpu {
            message: format!("failed to acquire device: {error}. Fix: check requested wgpu limits/features against the adapter and update the GPU driver if limits are unexpectedly low."),
        })
}
/// `Wake` implementation that unparks one specific thread.
///
/// `wait_for_gpu` builds its `Waker` from this type, so waking a pending
/// future releases the thread blocked in `thread::park`.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        // The consuming form has nothing extra to do; reuse the by-ref path.
        Self::wake_by_ref(&self);
    }

    fn wake_by_ref(self: &Arc<Self>) {
        Thread::unpark(&self.0);
    }
}
/// Drives `future` to completion on the current thread and returns its output.
///
/// This is a minimal blocking executor. The future is polled, and while it is
/// pending the thread parks until the waker, a `ThreadWaker` for this thread,
/// unparks it. The loop simply re-polls after any wake-up, including a
/// spurious return from `thread::park`.
fn wait_for_gpu<T>(future: impl Future<Output = T>) -> T {
    let parker = Arc::new(ThreadWaker(thread::current()));
    let waker = Waker::from(parker);
    let mut cx = Context::from_waker(&waker);
    // Stack-pin the future; it never leaves this frame.
    let mut pinned = std::pin::pin!(future);
    loop {
        if let Poll::Ready(output) = Pin::as_mut(&mut pinned).poll(&mut cx) {
            break output;
        }
        thread::park();
    }
}