//! vyre 0.4.0
//!
//! GPU compute intermediate representation with a standard operation library.
use std::sync::{mpsc, LazyLock};

use bytemuck::{cast_slice, Pod};
use wgpu::util::DeviceExt;

/// Process-wide device/queue pair, built lazily by `init_required_gpu` on first access.
static GPU: LazyLock<(wgpu::Device, wgpu::Queue)> = LazyLock::new(init_required_gpu);

/// Returns the shared device and queue, initialising them on the first call.
///
/// # Panics
/// Panics on first use if `init_required_gpu` cannot find the required adapter or
/// cannot create a device.
pub fn required_gpu() -> &'static (wgpu::Device, wgpu::Queue) {
    LazyLock::force(&GPU)
}

/// Compiles WGSL `source` and builds a compute pipeline for `entry_point`.
///
/// `label` is attached to both the shader module and the pipeline. The pipeline
/// layout is left to wgpu to derive from the shader (`layout: None`). Default
/// compilation options are used and no pipeline cache is supplied.
pub fn compile(
    device: &wgpu::Device,
    label: &'static str,
    source: &'static str,
    entry_point: &'static str,
) -> wgpu::ComputePipeline {
    let debug_label = Some(label);
    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: debug_label,
        source: wgpu::ShaderSource::Wgsl(source.into()),
    });
    let pipeline_desc = wgpu::ComputePipelineDescriptor {
        label: debug_label,
        layout: None,
        module: &shader,
        entry_point: Some(entry_point),
        compilation_options: Default::default(),
        cache: None,
    };
    device.create_compute_pipeline(&pipeline_desc)
}

/// Creates a buffer whose initial contents are the raw bytes of `data`.
///
/// `usage` is passed through unchanged; this function adds no flags of its own.
/// Callers that bind the buffer as storage must include `STORAGE` themselves.
pub fn storage_init<T: Pod>(
    device: &wgpu::Device,
    label: &'static str,
    data: &[T],
    usage: wgpu::BufferUsages,
) -> wgpu::Buffer {
    let contents: &[u8] = cast_slice(data);
    let init_desc = wgpu::util::BufferInitDescriptor {
        label: Some(label),
        contents,
        usage,
    };
    device.create_buffer_init(&init_desc)
}

/// Allocates a `byte_len`-byte buffer usable as a storage binding and as a copy source.
///
/// No initial data is supplied and the buffer is not mapped at creation.
pub fn storage_empty(device: &wgpu::Device, label: &'static str, byte_len: u64) -> wgpu::Buffer {
    let usage = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC;
    let desc = wgpu::BufferDescriptor {
        label: Some(label),
        size: byte_len,
        usage,
        mapped_at_creation: false,
    };
    device.create_buffer(&desc)
}

/// Allocates a `byte_len`-byte staging buffer that can receive copies and be mapped
/// for reading on the host.
///
/// The buffer is not mapped at creation.
pub fn readback_buffer(device: &wgpu::Device, label: &'static str, byte_len: u64) -> wgpu::Buffer {
    let usage = wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST;
    let desc = wgpu::BufferDescriptor {
        label: Some(label),
        size: byte_len,
        usage,
        mapped_at_creation: false,
    };
    device.create_buffer(&desc)
}

/// Reads the first `word_len` native-endian `u32` words from a host-mappable buffer.
///
/// The buffer is expected to have `MAP_READ` usage (for example one made by
/// `readback_buffer`) and not to be mapped already. The leading
/// `word_len * 4` bytes are mapped and the call blocks on
/// `device.poll(Maintain::Wait)` until the map callback has fired. The words are then
/// decoded, and the buffer is unmapped before returning.
///
/// A `word_len` of zero returns an empty vector without touching the buffer. wgpu
/// does not allow empty buffer slices, so mapping `0..0` would panic.
///
/// # Panics
/// Panics if the byte length overflows, if the map callback never reports a result,
/// or if the map operation reports an error.
pub fn read_u32(device: &wgpu::Device, buffer: &wgpu::Buffer, word_len: usize) -> Vec<u32> {
    if word_len == 0 {
        return Vec::new();
    }
    let byte_len = byte_len::<u32>(word_len);
    let slice = buffer.slice(0..byte_len);
    let (sender, receiver) = mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        // If the receiver is already gone there is nobody left to notify.
        let _ = sender.send(result);
    });
    // Block until submitted work (and therefore the map callback) has completed.
    match device.poll(wgpu::Maintain::Wait) {
        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
    }
    receiver
        .recv()
        .expect("readback callback must complete")
        .expect("readback buffer must map");
    let words = {
        let mapped = slice.get_mapped_range();
        // Decode byte-wise rather than reinterpreting a copied `Vec<u8>` with
        // `cast_slice`: a `Vec<u8>` is only guaranteed 1-byte alignment, so a
        // `[u8] -> [u32]` cast of it may panic. This also avoids an extra copy.
        mapped
            .chunks_exact(std::mem::size_of::<u32>())
            .map(|chunk| u32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect::<Vec<u32>>()
        // `mapped` goes out of scope here, so the view is released before `unmap`.
    };
    buffer.unmap();
    words
}

/// Returns the size in bytes of `count` consecutive values of `T`, as a `u64`
/// suitable for wgpu buffer sizes and ranges.
///
/// # Panics
/// Panics if `count * size_of::<T>()` overflows `usize`, or if the product does not
/// fit in a `u64`.
pub fn byte_len<T>(count: usize) -> u64 {
    let element_size = std::mem::size_of::<T>();
    match count.checked_mul(element_size) {
        Some(total) => u64::try_from(total).expect("buffer byte length must fit u64"),
        None => panic!("buffer byte length must not overflow usize"),
    }
}

/// Builds the device/queue pair used by `GPU`, insisting on an RTX 5090 adapter.
///
/// Adapters from all backends are inspected in enumeration order. The first one
/// whose name contains `"RTX 5090"` and which is not a CPU device is chosen, and its
/// identity is printed to stderr. The device is requested with no extra features and
/// with default limits, except that `max_storage_buffers_per_shader_stage` and
/// `max_compute_workgroup_storage_size` are set to the adapter's reported values.
///
/// # Panics
/// Panics with a "GPU required" message in three cases:
/// - no adapters are enumerated;
/// - no matching adapter is found (the message lists every adapter inspected);
/// - device creation fails.
fn init_required_gpu() -> (wgpu::Device, wgpu::Queue) {
    let instance = wgpu::Instance::default();
    let adapters = instance.enumerate_adapters(wgpu::Backends::all());
    assert!(
        !adapters.is_empty(),
        "GPU required: wgpu::Instance::enumerate_adapters returned no adapters. Fix: expose the RTX 5090 to the process and verify the NVIDIA driver."
    );

    // `find` stops at the first match, so `inspected` holds every adapter examined
    // up to that point (all of them when nothing matches).
    let mut inspected = Vec::new();
    let chosen = adapters.into_iter().find(|candidate| {
        let info = candidate.get_info();
        inspected.push(format!(
            "{} ({:?}, {:?})",
            info.name, info.backend, info.device_type
        ));
        info.name.contains("RTX 5090") && info.device_type != wgpu::DeviceType::Cpu
    });
    let Some(adapter) = chosen else {
        panic!(
            "GPU required: RTX 5090 adapter not found via enumerate_adapters. Found: {}. Fix: expose the NVIDIA RTX 5090 to wgpu.",
            inspected.join(", ")
        );
    };

    let info = adapter.get_info();
    eprintln!(
        "workgroup parity GPU: {} ({:?}, {:?})",
        info.name, info.backend, info.device_type
    );

    let adapter_limits = adapter.limits();
    let required_limits = wgpu::Limits {
        max_storage_buffers_per_shader_stage: adapter_limits.max_storage_buffers_per_shader_stage,
        max_compute_workgroup_storage_size: adapter_limits.max_compute_workgroup_storage_size,
        ..wgpu::Limits::default()
    };
    let device_desc = wgpu::DeviceDescriptor {
        label: Some("vyre workgroup parity device"),
        required_features: wgpu::Features::empty(),
        required_limits,
        memory_hints: wgpu::MemoryHints::default(),
    };
    match pollster::block_on(adapter.request_device(&device_desc, None)) {
        Ok(device_and_queue) => device_and_queue,
        Err(error) => panic!("GPU required: failed to create RTX 5090 wgpu device: {error}. Fix: update the NVIDIA driver or lower unsupported requested limits."),
    }
}