#![cfg(feature = "gpu")]
use std::sync::mpsc;
use vyre_conform::algebra::verify_gpu_laws_witnessed;
use vyre_conform::backend::{DispatchConfig, VyreBackend};
use vyre_conform::specs::primitive;
use vyre_conform::types::Convention;
use wgpu::util::DeviceExt;
// Passed as the third argument to `verify_gpu_laws_witnessed` for every
// primitive spec exercised by the test in this file.
// NOTE(review): presumably the number of witness inputs sampled per law —
// confirm against the `vyre_conform::algebra` documentation.
const GPU_LAW_WITNESSES: u64 = 128;
/// wgpu-backed `VyreBackend` used by the GPU law conformance test in this file.
struct WgpuBackend {
// Used to create buffers, shader modules, pipelines, bind groups and command encoders.
device: wgpu::Device,
// Used to submit the encoded compute pass and readback copy.
queue: wgpu::Queue,
}
impl WgpuBackend {
    /// Tries to bring up a wgpu device using default instance, adapter and
    /// device settings.
    ///
    /// Returns `None` when no adapter is offered or when the device request
    /// fails, so callers can skip GPU work instead of failing.
    fn new_if_available() -> Option<Self> {
        let instance = wgpu::Instance::default();

        // Blocks the current thread until the adapter request resolves.
        let adapter_options = wgpu::RequestAdapterOptions::default();
        let adapter = pollster::block_on(instance.request_adapter(&adapter_options))?;

        // Blocks again for the device; any request error is treated as "no GPU".
        let device_descriptor = wgpu::DeviceDescriptor::default();
        match pollster::block_on(adapter.request_device(&device_descriptor, None)) {
            Ok((device, queue)) => Some(Self { device, queue }),
            Err(_) => None,
        }
    }
}
/// `VyreBackend` implementation that executes conformance shaders through wgpu.
impl VyreBackend for WgpuBackend {
    /// Identifier reported for this backend.
    fn name(&self) -> &str {
        "wgpu"
    }

    /// Highest convention this backend reports.
    fn max_convention(&self) -> Convention {
        Convention::V1
    }

    /// Runs the compute shader `wgsl` once and returns the first `output_size`
    /// bytes of its output buffer.
    ///
    /// The shader must expose an entry point named `vyre_conform_main` and use
    /// bind group 0 with:
    /// - binding 0: storage buffer holding `input`, zero-padded by `padded`;
    /// - binding 1: storage buffer of `padded_size(output_size)` bytes;
    /// - binding 2: 16-byte uniform buffer built by `params_bytes` from the
    ///   padded input and output lengths, each expressed in 4-byte words.
    ///
    /// Only `config.workgroup_count` is read from `config`; it is used as the
    /// X dimension of a `(count, 1, 1)` dispatch.
    ///
    /// # Errors
    ///
    /// Returns `Err` when a padded length in words does not fit in `u32`, or
    /// when mapping/reading the readback buffer fails.
    ///
    /// NOTE(review): wgpu validation failures (for example invalid WGSL) are not
    /// converted into `Err` here; they go through wgpu's own error handling —
    /// confirm this is acceptable for callers.
    fn dispatch(
        &self,
        wgsl: &str,
        input: &[u8],
        output_size: usize,
        config: DispatchConfig,
    ) -> Result<Vec<u8>, String> {
        let input_padded = padded(input);
        let output_padded_size = padded_size(output_size);

        // The params uniform stores lengths as u32 word counts; refuse to
        // truncate silently if a length does not fit.
        let word_count = |byte_len: usize, what: &str| -> Result<u32, String> {
            <u32 as std::convert::TryFrom<usize>>::try_from(byte_len / 4).map_err(|error| {
                format!(
                    "{what} word count does not fit in u32: {error}. Fix: dispatch a smaller {what} buffer"
                )
            })
        };
        let params = params_bytes(
            word_count(input_padded.len(), "input")?,
            word_count(output_padded_size, "output")?,
        );

        let input_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("gpu algebra input"),
                contents: &input_padded,
                usage: wgpu::BufferUsages::STORAGE,
            });
        // The shader writes here; COPY_SRC lets the result be copied out afterwards.
        let output_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("gpu algebra output"),
            size: output_padded_size as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        // Separate CPU-mappable buffer that receives a copy of the output.
        let readback_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("gpu algebra readback"),
            size: output_padded_size as u64,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });
        let params_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("gpu algebra params"),
                contents: &params,
                usage: wgpu::BufferUsages::UNIFORM,
            });

        let module = self
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("gpu algebra shader"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });
        // `layout: None` asks wgpu to derive the pipeline layout from the
        // shader; the bind group layout is then queried from the pipeline.
        let pipeline = self
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("gpu algebra pipeline"),
                layout: None,
                module: &module,
                entry_point: Some("vyre_conform_main"),
                compilation_options: wgpu::PipelineCompilationOptions::default(),
                cache: None,
            });
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("gpu algebra bind group"),
            layout: &pipeline.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: input_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: output_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("gpu algebra encoder"),
            });
        {
            // Scoped so the pass is dropped (ending its borrow of the encoder)
            // before the copy is recorded.
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("gpu algebra pass"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(config.workgroup_count, 1, 1);
        }
        encoder.copy_buffer_to_buffer(
            &output_buffer,
            0,
            &readback_buffer,
            0,
            output_padded_size as u64,
        );
        self.queue.submit(std::iter::once(encoder.finish()));

        readback(
            &self.device,
            &readback_buffer,
            output_padded_size,
            output_size,
        )
    }
}
/// Checks every declared law of every primitive spec against the wgpu backend.
///
/// The test returns early (after printing a notice) when no wgpu adapter or
/// device can be obtained. Results reporting zero tested cases are ignored;
/// the first reported violation aborts the test with its details.
#[test]
fn all_primitive_laws_hold_on_wgpu_when_available() {
    let backend = match WgpuBackend::new_if_available() {
        Some(backend) => backend,
        None => {
            eprintln!("skipping GPU algebra L2 conformance: no wgpu adapter/device available");
            return;
        }
    };

    let specs = primitive::specs();
    assert_eq!(
        specs.len(),
        32,
        "primitive registry must contain all 32 ops"
    );

    // Running tallies over all results that actually exercised cases.
    let mut laws_checked = 0u64;
    let mut cases_checked = 0u64;

    for spec in &specs {
        let exercised = verify_gpu_laws_witnessed(&backend, spec, GPU_LAW_WITNESSES)
            .into_iter()
            .filter(|outcome| outcome.cases_tested != 0);

        for outcome in exercised {
            laws_checked += 1;
            cases_checked += outcome.cases_tested;

            let Some(violation) = outcome.violation else {
                continue;
            };
            panic!(
                "GPU LAW FAILED: {}\nOp: {}\nLaw: {}\na={}, b={}, c={}\ngpu/lhs={}, expected/rhs={}",
                violation.message,
                violation.op_id,
                violation.law,
                violation.a,
                violation.b,
                violation.c,
                violation.lhs,
                violation.rhs,
            );
        }
    }

    // Guard against a run that silently verified nothing (or almost nothing).
    assert!(laws_checked > 0, "no primitive laws were verified on GPU");
    assert!(
        cases_checked > 1_000,
        "too few GPU law cases: {cases_checked}"
    );
}
/// Maps `readback_buffer` for reading and returns a copy of its first
/// `output_size` bytes.
///
/// `padded_size` is the length in bytes of the range that gets mapped (starting
/// at offset 0); `output_size` is how many of those bytes are returned. The
/// buffer is unmapped again before a successful return.
///
/// # Errors
///
/// Returns `Err` when `output_size` exceeds `padded_size`, when the mapping
/// callback never delivers a result, or when the mapping itself fails.
fn readback(
    device: &wgpu::Device,
    readback_buffer: &wgpu::Buffer,
    padded_size: usize,
    output_size: usize,
) -> Result<Vec<u8>, String> {
    // Reject an impossible request up front instead of panicking on the
    // slice index after the buffer has been mapped.
    if output_size > padded_size {
        return Err(format!(
            "requested {output_size} output bytes from a {padded_size}-byte readback range. Fix: size the readback range to cover the requested output"
        ));
    }

    let slice = readback_buffer.slice(0..padded_size as u64);

    // The mapping result is delivered through a callback; forward it over a
    // channel so it can be awaited synchronously below.
    let (sender, receiver) = mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = sender.send(result);
    });

    // Block until the device has processed outstanding work. The poll result
    // itself is discarded; mapping failures are reported through the channel.
    let _ = device.poll(wgpu::Maintain::Wait);

    receiver
        .recv()
        .map_err(|error| {
            format!("failed to receive readback: {error}. Fix: check GPU device availability")
        })?
        .map_err(|error| {
            format!("failed to map readback buffer: {error:?}. Fix: check buffer usage flags")
        })?;

    let data = slice.get_mapped_range();
    let result = data[..output_size].to_vec();
    // The mapped view must be dropped before the buffer is unmapped.
    drop(data);
    readback_buffer.unmap();
    Ok(result)
}
/// Returns a copy of `input` extended with trailing zero bytes up to
/// `padded_size(input.len())` bytes.
fn padded(input: &[u8]) -> Vec<u8> {
    let target_len = padded_size(input.len());
    let mut buffer = Vec::with_capacity(target_len);
    buffer.extend_from_slice(input);
    // Fill the remainder with zeros so the buffer reaches the padded length.
    buffer.resize(target_len, 0);
    buffer
}
/// Rounds `size` up to the next multiple of 4 bytes, with a floor of 16 bytes.
fn padded_size(size: usize) -> usize {
    const WORD_MASK: usize = 4 - 1;
    const MIN_BYTES: usize = 16;

    let word_aligned = (size + WORD_MASK) & !WORD_MASK;
    std::cmp::max(word_aligned, MIN_BYTES)
}
/// Builds the 16-byte parameter block: `input_len` then `output_len`, each as
/// a little-endian `u32`, followed by eight zero bytes.
fn params_bytes(input_len: u32, output_len: u32) -> [u8; 16] {
    let mut block = [0u8; 16];
    // Only the first two 4-byte slots are written; the rest stay zero.
    for (slot, word) in block.chunks_exact_mut(4).zip([input_len, output_len]) {
        slot.copy_from_slice(&word.to_le_bytes());
    }
    block
}