/// Fallback entry point used when the crate is built without the `gpu` feature.
///
/// Reports the missing feature on stderr and terminates with exit status `1`.
#[cfg(not(feature = "gpu"))]
fn main() {
    let message = "This example requires the `gpu` feature.";
    eprintln!("{message}");
    std::process::exit(1);
}
/// GPU entry point: validates and lowers the `Xor` primitive to WGSL, runs it on
/// two constant input vectors, reads the output back, and compares every word
/// against the CPU-computed `a ^ b`.
///
/// Exit status `2` when `cached_device()` fails, `1` on the first mismatching
/// output word. Validation, WGSL lowering, and pipeline compilation failures panic.
#[cfg(feature = "gpu")]
fn main() {
    use std::iter;
    use vyre::ir::validate;
    use vyre::ops::primitive::xor::Xor;
    use vyre::runtime::{bg_entry, cached_device, compile_compute_pipeline};
    use wgpu::util::DeviceExt;

    /// Number of `u32` words in each input buffer and in the output buffer.
    const WORD_COUNT: usize = 64;
    /// Byte size shared by the output and readback buffers and the copy between them.
    const BYTE_LEN: u64 = (WORD_COUNT * std::mem::size_of::<u32>()) as u64;

    let (device, queue) = match cached_device() {
        Ok(pair) => pair,
        Err(error) => {
            eprintln!("GPU unavailable: {error}");
            std::process::exit(2);
        }
    };

    let program = Xor::program();
    let errors = validate(&program);
    assert!(
        errors.is_empty(),
        "validation failed for Xor::program(): {errors:?}"
    );
    let wgsl =
        vyre::lower::wgsl::lower(&program).expect("WGSL lowering must succeed for a valid program");
    let pipeline = compile_compute_pipeline(device, "xor_example", &wgsl, "main")
        .expect("pipeline compilation must succeed");

    // Complementary bit patterns: every bit differs between `a` and `b`, so the
    // reference `a ^ b` is 0xFFFFFFFF in every word.
    let a = vec![0xAAAA_AAAAu32; WORD_COUNT];
    let b = vec![0x5555_5555u32; WORD_COUNT];
    // CPU reference, derived from the inputs rather than hard-coded.
    let expected: Vec<u32> = a.iter().zip(&b).map(|(x, y)| x ^ y).collect();

    let a_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("xor_a"),
        contents: bytemuck::cast_slice(&a),
        usage: wgpu::BufferUsages::STORAGE,
    });
    let b_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("xor_b"),
        contents: bytemuck::cast_slice(&b),
        usage: wgpu::BufferUsages::STORAGE,
    });
    // Written by the shader, then copied into the CPU-mappable `readback` buffer.
    let out_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("xor_out"),
        size: BYTE_LEN,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("xor_bind_group"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            bg_entry(0, &a_buffer),
            bg_entry(1, &b_buffer),
            bg_entry(2, &out_buffer),
        ],
    });
    let readback = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("xor_readback"),
        size: BYTE_LEN,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("xor_encoder"),
    });
    {
        // Scoped so the compute pass ends (and releases its borrow of `encoder`)
        // before the copy command is recorded.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("xor_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        // NOTE(review): a single workgroup assumes the lowered shader's workgroup
        // size covers all WORD_COUNT words — confirm against the WGSL for Xor.
        pass.dispatch_workgroups(1, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback, 0, BYTE_LEN);
    queue.submit(iter::once(encoder.finish()));

    let output = read_u32_buffer(device, &readback, WORD_COUNT);
    if let Some(index) = output
        .iter()
        .zip(&expected)
        .position(|(got, want)| got != want)
    {
        eprintln!(
            "First divergence at index {}: expected 0x{:08X}, got 0x{:08X}",
            index, expected[index], output[index]
        );
        std::process::exit(1);
    }
    println!("GPU XOR dispatch successful — all {WORD_COUNT} elements are 0xFFFFFFFF");
}
/// Maps `buffer` for reading, blocks until the map completes, and returns its
/// first `word_len` `u32` words (native byte order, matching how the inputs were
/// uploaded with `bytemuck::cast_slice`).
///
/// `buffer` must have been created with `MAP_READ` usage and be at least
/// `word_len * 4` bytes long. The buffer is unmapped before returning.
///
/// # Panics
///
/// Panics if the map callback's sender is dropped without reporting, or if
/// mapping fails.
#[cfg(feature = "gpu")]
fn read_u32_buffer(device: &wgpu::Device, buffer: &wgpu::Buffer, word_len: usize) -> Vec<u32> {
    const WORD_BYTES: usize = std::mem::size_of::<u32>();

    let byte_len = (word_len * WORD_BYTES) as u64;
    let slice = buffer.slice(0..byte_len);
    let (sender, receiver) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        // If the receiver is already gone there is nobody to report to.
        let _ = sender.send(result);
    });
    // The poll outcome is discarded; whether the map succeeded is checked via
    // the callback result received below.
    let _ = device.poll(wgpu::Maintain::Wait);
    receiver
        .recv()
        .expect("readback map callback must report completion")
        .expect("readback buffer must map for reading");

    let words: Vec<u32> = {
        let mapped = slice.get_mapped_range();
        // Decode word-by-word instead of reinterpreting the bytes as `&[u32]`:
        // this imposes no 4-byte alignment requirement on the byte storage
        // (a `Vec<u8>` copy is only guaranteed 1-byte alignment, which made
        // `bytemuck::cast_slice` liable to panic) and avoids an extra copy.
        mapped
            .chunks_exact(WORD_BYTES)
            .map(|chunk| u32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect()
        // `mapped` is dropped here, which must happen before `unmap`.
    };
    buffer.unmap();
    words
}