use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};
use bytemuck::cast_slice;
use std::collections::{hash_map::DefaultHasher, HashSet};
use std::hash::{Hash, Hasher};
use std::sync::{LazyLock, RwLock};
use wgpu::util::DeviceExt;
use super::backend::WgpuBackend;
use super::byte_words::pad_to_words;
use super::capabilities::adapter_naga_capabilities;
use super::context::get_gpu;
use super::readback::wait_for_readback;
/// Upper bound (in bytes) for the output-buffer fill image that `dispatch_wgsl` builds in a
/// stack array; larger output buffers fall back to a heap-allocated `Vec`.
const STACK_INIT_BYTES: usize = 4 * 1024;
/// Report whether `wgsl` has been seen by an earlier call, recording it if not.
///
/// Sources are keyed by a 64-bit `DefaultHasher` fingerprint held in a process-wide set.
/// The first call for a given source returns `false` and inserts its fingerprint; later
/// calls return `true`. If either lock is poisoned, the function returns `false`, so the
/// caller falls back to validating again.
///
/// NOTE(review): the fingerprint is recorded here, before the caller validates the shader,
/// so a source that fails validation still reports `true` on later calls.
fn wgsl_already_validated(wgsl: &str) -> bool {
    static SEEN: LazyLock<RwLock<HashSet<u64>>> = LazyLock::new(|| RwLock::new(HashSet::new()));

    let fingerprint = {
        let mut state = DefaultHasher::new();
        wgsl.hash(&mut state);
        state.finish()
    };

    // Fast path: a shared read lock is enough when the source is already known.
    let known = match SEEN.read() {
        Ok(set) => set.contains(&fingerprint),
        Err(_) => return false,
    };
    if known {
        return true;
    }

    // Slow path: take the write lock. `insert` returns false if another thread raced us
    // and recorded the same fingerprint in between.
    match SEEN.write() {
        Ok(mut set) => !set.insert(fingerprint),
        Err(_) => false,
    }
}
impl WgslBackend for WgpuBackend {
    /// Backend identifier: always `"wgpu"`.
    fn name(&self) -> &str {
        "wgpu"
    }

    /// Hard-coded version string reported for this backend.
    fn version(&self) -> &str {
        "24.0"
    }

    /// The device's `max_compute_invocations_per_workgroup` limit, or `None` when
    /// `get_gpu()` finds no GPU.
    fn max_workgroup_invocations(&self) -> Option<u32> {
        get_gpu().map(|gpu| gpu.device.limits().max_compute_invocations_per_workgroup)
    }

    /// Run conformance WGSL through `dispatch_wgsl`, using `vyre_conform_main` as the
    /// compute entry point.
    fn dispatch(
        &self,
        wgsl: &str,
        input: &[u8],
        output_size: usize,
        config: ConformDispatchConfig,
    ) -> Result<Vec<u8>, String> {
        const ENTRY_POINT: &str = "vyre_conform_main";
        dispatch_wgsl(wgsl, ENTRY_POINT, input, output_size, config)
    }
}
impl vyre::VyreBackend for WgpuBackend {
    /// Backend identifier: always `"wgpu"`.
    fn id(&self) -> &'static str {
        "wgpu"
    }

    /// Lower `program` to WGSL and execute it through `dispatch_wgsl` with entry point `main`.
    ///
    /// Only `inputs[0]` is uploaded (an empty slice when `inputs` is empty) and `_config`
    /// is not consulted. The result holds a single buffer: the program's output bytes.
    ///
    /// NOTE(review): `workgroup_count` is the number of 32-bit output words (saturated to
    /// `u32::MAX`), not that count divided by `workgroup_size`; confirm this matches how the
    /// lowered WGSL maps invocations to output elements.
    fn dispatch(
        &self,
        program: &vyre::Program,
        inputs: &[Vec<u8>],
        _config: &vyre::DispatchConfig,
    ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
        let lowered = match vyre::lower::wgsl::lower(program) {
            Ok(source) => source,
            Err(error) => {
                return Err(vyre::BackendError::new(format!(
                    "failed to lower vyre IR to WGSL: {error}. Fix: provide a valid Program accepted by the WGSL lowering pipeline."
                )));
            }
        };

        let first_input: &[u8] = match inputs.first() {
            Some(bytes) => bytes,
            None => &[],
        };

        let output_size = output_size_from_program(program)?;
        let output_words = output_size.div_ceil(4).max(1);
        let config = ConformDispatchConfig {
            workgroup_size: program.workgroup_size[0].max(1),
            workgroup_count: u32::try_from(output_words).unwrap_or(u32::MAX),
            convention: crate::spec::types::Convention::V1,
            lookup_data: None,
            buffer_init: crate::spec::types::BufferInitPolicy::default(),
        };

        match dispatch_wgsl(&lowered, "main", first_input, output_size, config) {
            Ok(bytes) => Ok(vec![bytes]),
            Err(message) => Err(vyre::BackendError::new(message)),
        }
    }
}
/// Byte size of the program's first output buffer: element count × element size.
///
/// # Errors
///
/// Returns a `BackendError` when the program has no output buffer, when the output element
/// count does not fit in `usize`, or when count × element size overflows `usize`. Before,
/// the overflow case saturated silently to `usize::MAX`, which is never a meaningful size.
fn output_size_from_program(program: &vyre::Program) -> Result<usize, vyre::BackendError> {
    let output = program
        .buffers
        .iter()
        .find(|buffer| buffer.is_output())
        .ok_or_else(|| {
            vyre::BackendError::new(
                "program has no output buffer. Fix: declare exactly one output buffer in the vyre Program.",
            )
        })?;
    let count = usize::try_from(output.count()).map_err(|_| {
        vyre::BackendError::new(
            "program output element count exceeds usize. Fix: split the dispatch into smaller output buffers.",
        )
    })?;
    count
        .checked_mul(element_size_bytes(output.element()))
        .ok_or_else(|| {
            vyre::BackendError::new(
                "program output byte size overflows usize. Fix: split the dispatch into smaller output buffers.",
            )
        })
}
/// Size in bytes of one element of `data_type` in a GPU buffer.
///
/// `Bytes` counts as 1 byte, 64-bit and two-lane types as 8, `Vec4U32` as 16. The 32-bit
/// scalars and any type not listed explicitly count as one 4-byte word.
fn element_size_bytes(data_type: vyre::ir::DataType) -> usize {
    use vyre::ir::DataType as Dt;

    match data_type {
        Dt::Bytes => 1,
        Dt::Bool | Dt::U32 | Dt::I32 | Dt::F32 => 4,
        Dt::U64 | Dt::Vec2U32 => 8,
        Dt::Vec4U32 => 16,
        _ => 4,
    }
}
/// Compile `wgsl`, run one compute dispatch of `entry_point`, and return the first
/// `output_size` bytes of the output buffer.
///
/// Bind group 0 is populated as follows:
/// - binding 0: `input`, padded to whole 32-bit words (storage)
/// - binding 1: output buffer of `output_size` rounded up to whole words (minimum one word),
///   pre-filled with `0xCD` under `BufferInitPolicy::Poison` and `0x00` otherwise
/// - binding 2: uniform `[input.len(), output_word_count, 0, 0]` as `u32`s
/// - binding 3: `config.lookup_data`, padded to words, bound only when it is `Some`
///
/// The pipeline is dispatched as `config.workgroup_count × 1 × 1` workgroups.
///
/// # Errors
///
/// Returns a message starting with `Fix:` in these cases:
/// - no GPU is available;
/// - naga rejects the shader, or pipeline creation fails;
/// - a length does not fit the `u32` params uniform;
/// - a buffer or the workgroup count exceeds the device limits;
/// - mapping the readback buffer fails.
fn dispatch_wgsl(
    wgsl: &str,
    entry_point: &str,
    input: &[u8],
    output_size: usize,
    config: ConformDispatchConfig,
) -> Result<Vec<u8>, String> {
    let ctx = get_gpu().ok_or_else(|| {
        "Fix: no GPU adapter available. vyre-conform GPU parity tests require a GPU.".to_string()
    })?;

    // NOTE(review): `wgsl_already_validated` records the shader hash before validation runs,
    // so a shader rejected here is not re-validated on later calls; those calls then fail
    // at pipeline creation instead.
    if !wgsl_already_validated(wgsl) {
        let naga_caps =
            adapter_naga_capabilities(ctx.adapter_features, ctx.adapter_downlevel.clone());
        validate_wgsl_with_naga(wgsl, naga_caps)?;
    }

    let pipeline = vyre_wgpu::runtime::compile_compute_pipeline(
        &ctx.device,
        "vyre-conform pipeline",
        wgsl,
        entry_point,
    )
    .map_err(|e| format!("Fix: compute pipeline compilation failed: {e}"))?;

    // Check every length against the u32 params uniform and the device limits before any
    // buffer memory is allocated. Oversized requests then come back as `Err` instead of
    // being left to wgpu's device validation.
    let limits = ctx.device.limits();
    let max_binding = limits.max_storage_buffer_binding_size;
    let input_len_u32 = u32::try_from(input.len()).map_err(|_| {
        format!(
            "Fix: input length {} exceeds u32 capacity; split the dispatch into u32-sized chunks",
            input.len()
        )
    })?;
    let output_word_count = output_size.div_ceil(4).max(1);
    let output_len_u32 = u32::try_from(output_word_count).map_err(|_| {
        format!(
            "Fix: output_word_count {output_word_count} exceeds u32 capacity; reduce output_size"
        )
    })?;
    let output_bytes = output_word_count.checked_mul(4).ok_or_else(|| {
        format!(
            "Fix: output_size {output_size} overflows usize when rounded up to whole words; reduce output_size"
        )
    })?;
    check_storage_binding("output", output_bytes, max_binding)?;
    if config.workgroup_count > limits.max_compute_workgroups_per_dimension {
        return Err(format!(
            "Fix: workgroup_count {} exceeds the device limit of {} workgroups per dimension; raise workgroup_size or split the dispatch",
            config.workgroup_count, limits.max_compute_workgroups_per_dimension
        ));
    }

    // Binding 0: input bytes padded to whole words.
    let input_words = pad_to_words(input);
    // wgpu rejects zero-sized storage bindings. If padding produced no words (e.g. for an
    // empty input), bind a single zero word instead; params[0] still carries the real length.
    let input_contents: &[u8] = if input_words.is_empty() {
        &[0u8; 4]
    } else {
        cast_slice(&input_words)
    };
    check_storage_binding("input", input_contents.len(), max_binding)?;
    let input_buffer = ctx
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("input"),
            contents: input_contents,
            usage: wgpu::BufferUsages::STORAGE,
        });

    // Binding 1: output buffer, pre-filled according to the buffer-init policy.
    let init_byte = match config.buffer_init {
        crate::spec::types::BufferInitPolicy::Poison => 0xCD,
        _ => 0x00,
    };
    // Small fill images are sliced from a stack array; only larger ones allocate.
    let stack_init = [init_byte; STACK_INIT_BYTES];
    let heap_init;
    let output_init = if output_bytes <= stack_init.len() {
        &stack_init[..output_bytes]
    } else {
        heap_init = vec![init_byte; output_bytes];
        heap_init.as_slice()
    };
    let output_buffer = ctx
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("output"),
            contents: output_init,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        });

    // Binding 2: params uniform [input length in bytes, output length in words, 0, 0].
    let params = [input_len_u32, output_len_u32, 0u32, 0u32];
    let params_buffer = ctx
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("params"),
            contents: cast_slice(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

    // Binding 3 (optional): lookup table, padded and size-checked like the input.
    let lookup_buffer = match config.lookup_data {
        Some(ref lookup_data) => {
            let lookup_words = pad_to_words(lookup_data);
            let lookup_contents: &[u8] = if lookup_words.is_empty() {
                &[0u8; 4]
            } else {
                cast_slice(&lookup_words)
            };
            check_storage_binding("lookup", lookup_contents.len(), max_binding)?;
            Some(
                ctx.device
                    .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                        label: Some("lookup"),
                        contents: lookup_contents,
                        usage: wgpu::BufferUsages::STORAGE,
                    }),
            )
        }
        None => None,
    };

    let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("readback"),
        size: output_bytes as u64,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    let bind_group_layout = pipeline.get_bind_group_layout(0);
    let mut entries = Vec::with_capacity(4);
    entries.push(wgpu::BindGroupEntry {
        binding: 0,
        resource: input_buffer.as_entire_binding(),
    });
    entries.push(wgpu::BindGroupEntry {
        binding: 1,
        resource: output_buffer.as_entire_binding(),
    });
    entries.push(wgpu::BindGroupEntry {
        binding: 2,
        resource: params_buffer.as_entire_binding(),
    });
    if let Some(buffer) = &lookup_buffer {
        entries.push(wgpu::BindGroupEntry {
            binding: 3,
            resource: buffer.as_entire_binding(),
        });
    }
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vyre-conform bind group"),
        layout: &bind_group_layout,
        entries: &entries,
    });

    let mut encoder = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("vyre-conform dispatch"),
        });
    {
        // The pass borrows `encoder` mutably; end it before recording the copy below.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("vyre-conform compute"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups(config.workgroup_count, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&output_buffer, 0, &readback_buffer, 0, output_bytes as u64);
    ctx.queue.submit(std::iter::once(encoder.finish()));

    // Map the readback buffer and wait for the mapping callback via `wait_for_readback`.
    let slice = readback_buffer.slice(..);
    let (sender, receiver) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = sender.send(result);
    });
    wait_for_readback(&ctx.device, receiver)?
        .map_err(|e| format!("Fix: GPU readback mapping failed: {e:?}"))?;
    let mapped = slice.get_mapped_range();
    // `output_bytes` is `output_size` rounded up to whole words, so this slice is in bounds.
    let result = mapped[..output_size].to_vec();
    drop(mapped);
    readback_buffer.unmap();
    Ok(result)
}

/// Parse `wgsl` with naga and validate the module against `caps`.
///
/// Returns a `Fix:` message that distinguishes syntax (parse) errors from semantic
/// (validation) errors.
fn validate_wgsl_with_naga(wgsl: &str, caps: naga::valid::Capabilities) -> Result<(), String> {
    let module = naga::front::wgsl::parse_str(wgsl).map_err(|e| {
        format!(
            "Fix: WGSL shader fails naga parsing: {e}. The shader source is syntactically invalid."
        )
    })?;
    naga::valid::Validator::new(naga::valid::ValidationFlags::all(), caps)
        .validate(&module)
        .map_err(|e| {
            format!(
                "Fix: WGSL shader fails naga validation: {e}. The shader parses but has semantic errors."
            )
        })?;
    Ok(())
}

/// Return an error when a storage buffer of `bytes` bytes is larger than `limit`, the
/// device's `max_storage_buffer_binding_size`. Each buffer here is bound in full.
fn check_storage_binding(label: &str, bytes: usize, limit: u32) -> Result<(), String> {
    if bytes as u64 <= u64::from(limit) {
        return Ok(());
    }
    Err(format!(
        "Fix: {label} buffer of {bytes} bytes exceeds the device storage binding limit of {limit} bytes; split the dispatch into smaller buffers"
    ))
}