use std::sync::{mpsc, LazyLock};
use bytemuck::{cast_slice, Pod};
use wgpu::util::DeviceExt;
/// Process-wide wgpu device/queue pair, built on first access by `init_required_gpu`.
static GPU: LazyLock<(wgpu::Device, wgpu::Queue)> = LazyLock::new(init_required_gpu);

/// Returns the shared device and queue, initialising them on the first call.
///
/// # Panics
///
/// Panics whenever `init_required_gpu` panics: no adapters, no non-CPU RTX 5090
/// adapter, or device creation failure. Under `LazyLock` semantics the cell is
/// then poisoned, so later calls panic as well.
pub fn required_gpu() -> &'static (wgpu::Device, wgpu::Queue) {
    LazyLock::force(&GPU)
}
/// Compiles WGSL `source` and builds a compute pipeline for `entry_point`.
///
/// `label` is attached to both the shader module and the pipeline. The pipeline
/// layout is left for wgpu to infer from the shader (`layout: None`). Default
/// compilation options are used, with no pipeline cache.
///
/// The string arguments only need to live for the duration of this call. That
/// means both `'static` literals and WGSL generated at runtime (for example with
/// `format!`, to substitute workgroup sizes) are accepted.
pub fn compile(
    device: &wgpu::Device,
    label: &str,
    source: &str,
    entry_point: &str,
) -> wgpu::ComputePipeline {
    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some(label),
        source: wgpu::ShaderSource::Wgsl(source.into()),
    });
    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some(label),
        layout: None,
        module: &module,
        entry_point: Some(entry_point),
        compilation_options: wgpu::PipelineCompilationOptions::default(),
        cache: None,
    })
}
/// Creates a buffer whose initial contents are the raw bytes of `data`.
///
/// `T: Pod` lets the slice be reinterpreted as bytes without copying element by
/// element. `usage` is passed through verbatim; this helper adds no usage flags
/// of its own.
pub fn storage_init<T: Pod>(
    device: &wgpu::Device,
    label: &'static str,
    data: &[T],
    usage: wgpu::BufferUsages,
) -> wgpu::Buffer {
    let bytes: &[u8] = cast_slice(data);
    let descriptor = wgpu::util::BufferInitDescriptor {
        label: Some(label),
        contents: bytes,
        usage,
    };
    device.create_buffer_init(&descriptor)
}
/// Allocates an unmapped buffer of `byte_len` bytes with `STORAGE | COPY_SRC` usage.
///
/// The buffer can be bound as shader storage and copied out of, for example
/// into a buffer from `readback_buffer`. `COPY_DST` is not requested.
pub fn storage_empty(device: &wgpu::Device, label: &'static str, byte_len: u64) -> wgpu::Buffer {
    let usage = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC;
    let descriptor = wgpu::BufferDescriptor {
        label: Some(label),
        size: byte_len,
        usage,
        mapped_at_creation: false,
    };
    device.create_buffer(&descriptor)
}
/// Allocates an unmapped `MAP_READ | COPY_DST` buffer of `byte_len` bytes.
///
/// GPU results are meant to be copied into this buffer and then read on the
/// host (see `read_u32`).
pub fn readback_buffer(device: &wgpu::Device, label: &'static str, byte_len: u64) -> wgpu::Buffer {
    let usage = wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST;
    let descriptor = wgpu::BufferDescriptor {
        label: Some(label),
        size: byte_len,
        usage,
        mapped_at_creation: false,
    };
    device.create_buffer(&descriptor)
}
/// Reads the first `word_len` `u32` words out of a host-mappable buffer.
///
/// The function:
/// 1. maps the leading `word_len * 4` bytes of `buffer` for reading;
/// 2. blocks on `device.poll(Maintain::Wait)` until the map callback reports;
/// 3. decodes the bytes as native-endian `u32`s;
/// 4. unmaps the buffer so it can be reused.
///
/// `buffer` needs `MAP_READ` usage (e.g. one from `readback_buffer`) and must
/// hold at least `word_len * 4` bytes. Callers should submit any GPU copy into
/// it before calling this.
///
/// Returns an empty vector without touching the buffer when `word_len` is 0.
///
/// # Panics
///
/// Panics if the byte length overflows, if the map callback never reports
/// back, or if mapping fails.
pub fn read_u32(device: &wgpu::Device, buffer: &wgpu::Buffer, word_len: usize) -> Vec<u32> {
    // `Buffer::slice` panics on an empty range ("Buffer slices can not be
    // empty"), so a zero-word read must short-circuit before slicing.
    if word_len == 0 {
        return Vec::new();
    }
    let byte_len = byte_len::<u32>(word_len);
    let slice = buffer.slice(0..byte_len);
    let (sender, receiver) = mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        // A failed send can only mean the receiver was dropped. Nothing useful
        // can be done about that from inside the callback, so it is ignored.
        let _ = sender.send(result);
    });
    // Either outcome is fine: all we need is for the map callback to have run.
    match device.poll(wgpu::Maintain::Wait) {
        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
    }
    receiver
        .recv()
        .expect("readback callback must complete")
        .expect("readback buffer must map");
    let mapped = slice.get_mapped_range();
    // Decode word by word instead of copying into a `Vec<u8>` and using
    // `cast_slice::<u8, u32>`. A `Vec<u8>` only guarantees 1-byte alignment,
    // and bytemuck panics if the slice is not 4-byte aligned. This form is
    // alignment-independent and skips the extra copy.
    let words: Vec<u32> = mapped
        .chunks_exact(std::mem::size_of::<u32>())
        .map(|word| u32::from_ne_bytes([word[0], word[1], word[2], word[3]]))
        .collect();
    // The mapped view must be released before the buffer can be unmapped.
    drop(mapped);
    buffer.unmap();
    words
}
/// Returns the size in bytes of `count` consecutive values of `T`, as the `u64`
/// that wgpu uses for buffer sizes and ranges.
///
/// # Panics
///
/// Panics if `count * size_of::<T>()` overflows `usize`, or if the product does
/// not fit in `u64`. The second case can only happen where `usize` is wider
/// than 64 bits.
pub fn byte_len<T>(count: usize) -> u64 {
    let element_size = std::mem::size_of::<T>();
    let total = match count.checked_mul(element_size) {
        Some(total) => total,
        None => panic!("buffer byte length must not overflow usize"),
    };
    u64::try_from(total).expect("buffer byte length must fit u64")
}
/// Builds the device/queue pair that backs `GPU`.
///
/// Enumerates adapters on all backends and takes the first non-CPU adapter
/// whose name contains "RTX 5090", then logs it to stderr. It then requests a
/// device with no extra features and default limits. The exceptions are the
/// per-stage storage-buffer count and the compute workgroup storage size,
/// which are set to the adapter's own values.
///
/// # Panics
///
/// Panics with a remediation hint in three cases:
/// - no adapters are found at all;
/// - none of the adapters is a non-CPU RTX 5090 (the message lists every
///   adapter inspected);
/// - device creation fails.
fn init_required_gpu() -> (wgpu::Device, wgpu::Queue) {
    let instance = wgpu::Instance::default();
    let adapters = instance.enumerate_adapters(wgpu::Backends::all());
    if adapters.is_empty() {
        panic!("GPU required: wgpu::Instance::enumerate_adapters returned no adapters. Fix: expose the RTX 5090 to the process and verify the NVIDIA driver.");
    }

    // Human-readable description of every adapter inspected, reported if the search fails.
    let mut inspected: Vec<String> = Vec::new();
    let adapter = adapters
        .into_iter()
        .find(|candidate| {
            let info = candidate.get_info();
            inspected.push(format!(
                "{} ({:?}, {:?})",
                info.name, info.backend, info.device_type
            ));
            info.name.contains("RTX 5090") && info.device_type != wgpu::DeviceType::Cpu
        })
        .unwrap_or_else(|| {
            panic!(
                "GPU required: RTX 5090 adapter not found via enumerate_adapters. Found: {}. Fix: expose the NVIDIA RTX 5090 to wgpu.",
                inspected.join(", ")
            )
        });

    let chosen = adapter.get_info();
    eprintln!(
        "workgroup parity GPU: {} ({:?}, {:?})",
        chosen.name, chosen.backend, chosen.device_type
    );

    let supported = adapter.limits();
    let descriptor = wgpu::DeviceDescriptor {
        label: Some("vyre workgroup parity device"),
        required_features: wgpu::Features::empty(),
        required_limits: wgpu::Limits {
            max_storage_buffers_per_shader_stage: supported.max_storage_buffers_per_shader_stage,
            max_compute_workgroup_storage_size: supported.max_compute_workgroup_storage_size,
            ..wgpu::Limits::default()
        },
        memory_hints: wgpu::MemoryHints::default(),
    };
    match pollster::block_on(adapter.request_device(&descriptor, None)) {
        Ok(device_and_queue) => device_and_queue,
        Err(error) => panic!("GPU required: failed to create RTX 5090 wgpu device: {error}. Fix: update the NVIDIA driver or lower unsupported requested limits."),
    }
}