vyre-wgpu 0.1.0

wgpu backend for the vyre IR: implements `VyreBackend` and owns the GPU runtime, buffer pool, and pipeline cache.
Documentation
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};
use vyre::error::{Error, Result};

/// Address of the device owned by the cached GPU singleton, or `0` when no
/// address has been recorded yet.
///
/// Only the address is stored; in this module it is compared against other
/// device addresses and never dereferenced. [`cached_gpu`] writes it before
/// publishing the singleton, so anyone holding the cached device observes it.
pub(crate) static CACHED_DEVICE_PTR: AtomicUsize = AtomicUsize::new(0);

/// Process-wide GPU state held by the [`cached_gpu`] singleton.
pub(crate) struct CachedGpu {
    /// Device and queue created by [`init_device`].
    pub(crate) pair: (wgpu::Device, wgpu::Queue),
}

/// Return the process-wide cached device/queue pair, creating it on first use.
///
/// Repeated calls (e.g. across scans) return the same pair instead of
/// re-initializing the GPU.
///
/// # Errors
///
/// Returns an error if the GPU adapter or device cannot be initialized, or if
/// the cache initialization lock is poisoned.
#[inline]
pub fn cached_device() -> Result<&'static (wgpu::Device, wgpu::Queue)> {
    let cached = cached_gpu()?;
    // `cached_gpu` already records this address when it initializes the
    // singleton; re-storing the same value here is idempotent.
    CACHED_DEVICE_PTR.store(
        std::ptr::from_ref::<wgpu::Device>(&cached.pair.0).addr(),
        Ordering::Release,
    );
    Ok(&cached.pair)
}

/// Return the process-wide cached GPU state (device and queue), initializing
/// it on first use.
///
/// The first successful call creates a device/queue pair via [`init_device`]
/// and keeps it for the lifetime of the process; later calls return the same
/// instance. First-time initialization is serialized by a mutex so concurrent
/// callers create only one device. The initializing call also records the
/// device address in [`CACHED_DEVICE_PTR`] *before* publishing the singleton,
/// so [`is_cached_device`] recognizes the cached device even when callers
/// never go through [`cached_device`].
///
/// # Errors
///
/// Returns an error if the initialization mutex is poisoned, or if the GPU
/// adapter or device cannot be initialized. A failed initialization is not
/// cached, so the next call retries.
#[inline]
pub(crate) fn cached_gpu() -> Result<&'static CachedGpu> {
    static GPU: OnceLock<Arc<CachedGpu>> = OnceLock::new();
    static INIT_LOCK: Mutex<()> = Mutex::new(());
    // Fast path: already initialized, no locking needed.
    if let Some(gpu) = GPU.get() {
        return Ok(gpu.as_ref());
    }

    let _guard = INIT_LOCK.lock().map_err(|source| Error::Gpu {
        message: format!(
            "cached GPU init mutex poisoned: {source}. Fix: restart the process or avoid panicking while initializing the GPU cache."
        ),
    })?;
    // Another thread may have finished initializing while we waited for the lock.
    if let Some(gpu) = GPU.get() {
        return Ok(gpu.as_ref());
    }

    let gpu = Arc::new(CachedGpu {
        pair: init_device()?,
    });
    // The Arc's heap allocation does not move when it is stored in `GPU`, so
    // the device address can be recorded before the cache is published. A
    // thread that observes `GPU` as set therefore also observes this store.
    CACHED_DEVICE_PTR.store(
        std::ptr::from_ref::<wgpu::Device>(&gpu.pair.0).addr(),
        Ordering::Release,
    );
    // `INIT_LOCK` is held and `GPU` was observed empty, so this cannot lose a
    // race; the error below only guards against an internal logic error.
    let _ = GPU.set(gpu);
    let gpu = GPU.get().ok_or_else(|| Error::Gpu {
        message: "cached GPU initialization failed after successful device creation. Fix: report this vyre runtime bug with the active platform and wgpu backend.".to_string(),
    })?;
    Ok(gpu.as_ref())
}

/// Return true when `device` is the device owned by the cached GPU singleton.
///
/// Identity is by address: `device` matches only if its address equals the
/// value recorded in [`CACHED_DEVICE_PTR`]. Because a reference's address is
/// never `0`, this returns false until an address has been recorded.
#[inline]
pub(crate) fn is_cached_device(device: &wgpu::Device) -> bool {
    let device_ptr = std::ptr::from_ref(device).addr();
    CACHED_DEVICE_PTR.load(Ordering::Acquire) == device_ptr
}

/// Synchronously create a brand-new GPU device and queue.
///
/// This drives [`acquire_gpu`] to completion on the calling thread, blocking
/// it until the future resolves. Each call goes through [`acquire_gpu`]
/// again, so it does not reuse the pair returned by [`cached_device`].
///
/// # Errors
///
/// Returns an actionable GPU error if no compatible adapter is available, if
/// the selected adapter is CPU-backed, or if device creation fails.
#[inline]
pub fn init_device() -> Result<(wgpu::Device, wgpu::Queue)> {
    let pending_device = acquire_gpu();
    wait_for_gpu(pending_device)
}

/// Asynchronously request a wgpu adapter and open a device/queue pair on it.
///
/// The default [`wgpu::Instance`] and default adapter options are used.
/// Adapters whose device type is [`wgpu::DeviceType::Cpu`] or
/// [`wgpu::DeviceType::Other`] are rejected. `TIMESTAMP_QUERY` is enabled only
/// when the adapter supports it. The device uses `wgpu::Limits::default()`,
/// except that `max_storage_buffers_per_shader_stage` is taken from the
/// adapter.
///
/// # Errors
///
/// Returns an actionable GPU error if no compatible adapter is available, if
/// the selected adapter is CPU-backed (or of unknown type), or if device
/// creation fails.
#[inline]
pub async fn acquire_gpu() -> Result<(wgpu::Device, wgpu::Queue)> {
    let instance = wgpu::Instance::default();
    let Some(adapter) = instance
        .request_adapter(&wgpu::RequestAdapterOptions::default())
        .await
    else {
        return Err(Error::Gpu {
            message: "failed to acquire adapter. Fix: install a compatible GPU driver, expose a wgpu-supported adapter, or run on a host with GPU access.".to_string(),
        });
    };

    // Refuse software rasterizers and unclassified adapters up front.
    let info = adapter.get_info();
    let not_a_real_gpu = matches!(
        info.device_type,
        wgpu::DeviceType::Cpu | wgpu::DeviceType::Other
    );
    if not_a_real_gpu {
        return Err(Error::Gpu {
            message: format!(
                "adapter '{}' has device type {:?}, which is not a real GPU execution target. Fix: expose a discrete, integrated, or virtual GPU adapter before running vyre.",
                info.name, info.device_type
            ),
        });
    }

    // Opt into timestamp queries only when the adapter offers them: the
    // intersection is either `TIMESTAMP_QUERY` or empty.
    let required_features = adapter.features() & wgpu::Features::TIMESTAMP_QUERY;
    let required_limits = wgpu::Limits {
        max_storage_buffers_per_shader_stage: adapter.limits().max_storage_buffers_per_shader_stage,
        ..wgpu::Limits::default()
    };
    let descriptor = wgpu::DeviceDescriptor {
        label: Some("vyre device"),
        required_features,
        required_limits,
        memory_hints: wgpu::MemoryHints::default(),
    };

    adapter
        .request_device(&descriptor, None)
        .await
        .map_err(|error| Error::Gpu {
            message: format!("failed to acquire device: {error}. Fix: check requested wgpu limits/features against the adapter and update the GPU driver if limits are unexpectedly low."),
        })
}

/// Waker that unparks the thread blocked inside [`wait_for_gpu`].
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        // Unparking only needs a shared reference, so reuse the by-ref path.
        Self::wake_by_ref(&self);
    }

    fn wake_by_ref(self: &Arc<Self>) {
        let parked: &Thread = &self.0;
        parked.unpark();
    }
}

/// Block the current thread until `future` completes and return its output.
///
/// This is a minimal single-future executor. The future is polled on the
/// calling thread, and whenever it reports `Pending` the thread parks until
/// the future's waker unparks it. A wake that arrives before `park` leaves an
/// unpark token behind, so it is not lost. Spurious wake-ups just cause
/// another poll.
fn wait_for_gpu<T>(future: impl Future<Output = T>) -> T {
    let this_thread = thread::current();
    let waker = Waker::from(Arc::new(ThreadWaker(this_thread)));
    let mut cx = Context::from_waker(&waker);
    // Pin on the stack; the future never leaves this frame.
    let mut pinned: Pin<&mut _> = std::pin::pin!(future);
    loop {
        if let Poll::Ready(output) = pinned.as_mut().poll(&mut cx) {
            break output;
        }
        thread::park();
    }
}