vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
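//! Lazy GPU context creation and caching.
//!
//! [`GpuContext`] holds the wgpu `Device` and `Queue` together with the
//! adapter's feature set and downlevel capabilities. It is created on first
//! use and cached in a global, mutex-guarded `OnceLock` slot so that every
//! dispatch shares the same device, avoiding expensive re-initialization
//! between tests. If wgpu reports the device as lost, the cached context is
//! discarded and the next request creates a fresh one.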

use std::sync::{Arc, Mutex, MutexGuard, OnceLock};

/// Shared GPU device + queue. Created once, reused across all dispatches.
///
/// Besides the device and its queue, the context records what the adapter
/// reported when the device was created so callers can check capabilities
/// without holding on to the adapter itself.
pub(super) struct GpuContext {
    /// Logical wgpu device used for every dispatch.
    pub(super) device: wgpu::Device,
    /// Queue returned alongside `device` by the device request.
    pub(super) queue: wgpu::Queue,
    /// Feature set reported by the adapter (`Adapter::features`).
    pub(super) adapter_features: wgpu::Features,
    /// Downlevel capabilities reported by the adapter
    /// (`Adapter::get_downlevel_capabilities`).
    pub(super) adapter_downlevel: wgpu::DownlevelCapabilities,
}

/// Process-wide slot for the cached GPU context; `None` inside the mutex means
/// no context is currently cached.
static GPU: OnceLock<Mutex<Option<Arc<GpuContext>>>> = OnceLock::new();

/// Returns the global cache slot, creating it (empty) on the first call.
fn gpu_cache() -> &'static Mutex<Option<Arc<GpuContext>>> {
    // `Mutex::default()` wraps `Option::default()`, i.e. an empty slot.
    GPU.get_or_init(Mutex::default)
}

/// Locks the global cache slot and returns the guard.
///
/// A poisoned mutex is not treated as an error: the guard is recovered from
/// the poison error and returned as usual, so a panic in one holder never
/// makes the cache permanently unusable.
fn lock_gpu_cache() -> MutexGuard<'static, Option<Arc<GpuContext>>> {
    gpu_cache()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}

/// Device-lost callback: logs the loss and evicts the cached context.
///
/// `reason` and `message` are the values wgpu passes to the device-lost
/// callback; they are only used for the warning log. After this returns the
/// cache slot is empty, so the next lookup has to build a new context.
fn clear_gpu_cache(reason: wgpu::DeviceLostReason, message: String) {
    tracing::warn!(
        "vyre-conform GPU device lost: {:?}: {}. Clearing cached device.",
        reason,
        message
    );
    // Move the stale context out of the slot while the lock is held, but let
    // the guard go (end of this statement) before the context is dropped.
    // Dropping the last reference tears down the wgpu device; if that teardown
    // ever re-enters this callback, it must not find the cache lock still held
    // by this thread, because the mutex is not reentrant.
    let stale = lock_gpu_cache().take();
    drop(stale);
}

/// Returns the shared GPU context, creating and caching it if the cache is
/// currently empty.
///
/// The cache lock is held for the whole call, including context creation, so
/// concurrent callers cannot each build their own device. Returns `None` when
/// a context cannot be created; in that case the cache is left empty and a
/// later call will try again.
pub(super) fn get_gpu() -> Option<Arc<GpuContext>> {
    let mut slot = lock_gpu_cache();

    // Fill the slot on a miss; `?` bails out without touching the slot when
    // creation fails.
    if slot.is_none() {
        *slot = Some(Arc::new(create_gpu_context()?));
    }

    // Hand out a new reference to whatever the slot now holds.
    slot.as_ref().map(Arc::clone)
}

/// Builds a fresh [`GpuContext`].
///
/// Requests an adapter with `PowerPreference::HighPerformance` across all
/// backends, validates it, then requests a device and queue from it.
///
/// Returns `None` when no adapter is available or when the adapter fails to
/// provide a device (the device error is logged as a warning).
///
/// # Panics
///
/// Panics if the selected adapter reports a `Cpu` or `Other` device type, or
/// if it lacks the `COMPUTE_SHADERS` downlevel flag.
fn create_gpu_context() -> Option<GpuContext> {
    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
        backends: wgpu::Backends::all(),
        ..Default::default()
    });
    let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
        power_preference: wgpu::PowerPreference::HighPerformance,
        ..Default::default()
    }))?;
    let info = adapter.get_info();
    let features = adapter.features();
    let downlevel = adapter.get_downlevel_capabilities();

    // Validate the adapter before any device exists. A rejected adapter then
    // panics without a live device or a registered device-lost callback to
    // unwind through.
    assert!(
        !matches!(info.device_type, wgpu::DeviceType::Cpu | wgpu::DeviceType::Other),
        "Fix: adapter {} is not a GPU (type {:?}). vyre-conform requires a discrete or integrated GPU.",
        info.name,
        info.device_type
    );
    assert!(
        downlevel.flags.contains(wgpu::DownlevelFlags::COMPUTE_SHADERS),
        "Fix: adapter {} does not support compute shaders. vyre-conform requires baseline compute capability.",
        info.name
    );

    let (device, queue) = match pollster::block_on(adapter.request_device(
        &wgpu::DeviceDescriptor {
            label: Some("vyre-conform"),
            ..Default::default()
        },
        None,
    )) {
        Ok(pair) => pair,
        Err(err) => {
            // Surface the reason instead of silently turning it into `None`.
            tracing::warn!(
                "vyre-conform could not create a device on adapter {}: {}",
                info.name,
                err
            );
            return None;
        }
    };
    tracing::info!(
        "vyre-conform using GPU adapter: {} ({:?})",
        info.name,
        info.device_type
    );

    // Register the eviction callback last, once nothing below can fail.
    device.set_device_lost_callback(clear_gpu_cache);

    Some(GpuContext {
        device,
        queue,
        adapter_features: features,
        adapter_downlevel: downlevel,
    })
}