#![cfg(feature = "gpu")]
use std::sync::mpsc;
use bytemuck::cast_slice;
use vyre::ir::{validate, BufferDecl, DataType, Expr, Node, Program};
use vyre::ops::match_ops::dfa_scan::{DfaScan, NO_MATCH};
use wgpu::util::DeviceExt;
/// Width of one DFA transition-table row: one next-state entry per possible input byte
/// value, so the entry for `(state, byte)` lives at `state * BYTE_CLASSES + byte`.
const BYTE_CLASSES: usize = 256;
/// Element-wise add built from a plain `Expr::add` node, dispatched on the GPU.
#[test]
fn manual_ir_add_dispatches_on_gpu() {
    const LHS: [u32; 4] = [1, 2, 3, 4];
    const RHS: [u32; 4] = [10, 20, 30, 40];

    let Some((device, queue)) = cached_device() else {
        return;
    };
    let sums = dispatch_three_buffer_u32(device, queue, &add_program(false), &LHS, &RHS);
    assert_eq!(sums, [11, 22, 33, 44]);
}
/// Element-wise add expressed as a `primitive.math.add` call (expected to be inlined
/// during lowering, per the test name), dispatched on the GPU.
#[test]
fn call_inlined_add_dispatches_on_gpu() {
    const LHS: [u32; 4] = [1, 2, 3, 4];
    const RHS: [u32; 4] = [10, 20, 30, 40];

    let Some((device, queue)) = cached_device() else {
        return;
    };
    let sums = dispatch_three_buffer_u32(device, queue, &add_program(true), &LHS, &RHS);
    assert_eq!(sums, [11, 22, 33, 44]);
}
/// Runs the library-provided `DfaScan` program over a small haystack with a
/// hand-built two-pattern DFA and checks that both pattern ids are reported.
#[test]
fn dfa_scan_ir_dispatches_on_gpu() {
    // Number of u32 fields in one match record, as decoded into `vyre::Match` below.
    const MATCH_FIELDS: usize = 3;

    let Some((device, queue)) = cached_device() else {
        return;
    };
    let program = DfaScan::program();
    assert_validation_passes(&program);
    let pipeline = compile_program(device, &program, "gpu_dispatch_dfa_scan");

    // "ab" starts at byte 2 and "cd" at byte 6.
    let input = b"xxabxxcdxx";
    let input_words = pack_input_words(input);

    // Transition table: `state_count` rows of BYTE_CLASSES next-state entries.
    //   0 --'a'--> 1 --'b'--> 2   (accept_map[2] = pattern 0, "ab")
    //   0 --'c'--> 3 --'d'--> 4   (accept_map[4] = pattern 1, "cd")
    // Every entry not set below stays 0, i.e. falls back to the start state.
    let state_count = 5;
    let mut transitions = vec![0u32; state_count * BYTE_CLASSES];
    transitions[b'a' as usize] = 1;
    transitions[BYTE_CLASSES + b'b' as usize] = 2;
    transitions[b'c' as usize] = 3;
    transitions[3 * BYTE_CLASSES + b'd' as usize] = 4;
    let mut accept_map = vec![NO_MATCH; state_count];
    accept_map[2] = 0;
    accept_map[4] = 1;

    // The matches buffer holds up to `max_matches` records of MATCH_FIELDS words each.
    let max_matches = 8u32;
    let match_words = usize::try_from(max_matches)
        .expect("max_matches fits usize")
        .checked_mul(MATCH_FIELDS)
        .expect("match buffer word count fits usize");
    let match_bytes = byte_len(match_words);

    // NOTE(review): slot meanings are defined by `DfaScan`; the trailing 0 is assumed unused.
    let params = [
        u32::try_from(input.len()).expect("input len fits u32"),
        u32::try_from(state_count).expect("state count fits u32"),
        max_matches,
        0,
    ];

    let input_buffer = storage_init(device, "gpu_dispatch_dfa_input", &input_words);
    let transition_buffer = storage_init(device, "gpu_dispatch_dfa_transitions", &transitions);
    let accept_buffer = storage_init(device, "gpu_dispatch_dfa_accept", &accept_map);
    let matches_buffer = storage_empty(
        device,
        "gpu_dispatch_dfa_matches",
        match_words,
        wgpu::BufferUsages::COPY_SRC,
    );
    // COPY_DST is needed so the counter can be cleared on the encoder below.
    let count_buffer = storage_empty(
        device,
        "gpu_dispatch_dfa_count",
        1,
        wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST,
    );
    let params_buffer = storage_init(device, "gpu_dispatch_dfa_params", &params);
    // One length per pattern id: "ab" and "cd" are both 2 bytes long.
    let pattern_lengths_buffer = storage_init(device, "gpu_dispatch_dfa_pattern_lengths", &[2, 2]);

    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("gpu_dispatch_dfa_bind_group"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            vyre_wgpu::runtime::bg_entry(0, &input_buffer),
            vyre_wgpu::runtime::bg_entry(1, &transition_buffer),
            vyre_wgpu::runtime::bg_entry(2, &accept_buffer),
            vyre_wgpu::runtime::bg_entry(3, &matches_buffer),
            vyre_wgpu::runtime::bg_entry(4, &count_buffer),
            vyre_wgpu::runtime::bg_entry(5, &params_buffer),
            vyre_wgpu::runtime::bg_entry(6, &pattern_lengths_buffer),
        ],
    });

    let count_readback = readback_buffer(device, "gpu_dispatch_dfa_count_readback", byte_len(1));
    let matches_readback =
        readback_buffer(device, "gpu_dispatch_dfa_matches_readback", match_bytes);

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("gpu_dispatch_dfa_encoder"),
    });
    // Zero the match counter before the scan pass runs.
    encoder.clear_buffer(&count_buffer, 0, None);
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("gpu_dispatch_dfa_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups(1, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&count_buffer, 0, &count_readback, 0, byte_len(1));
    encoder.copy_buffer_to_buffer(&matches_buffer, 0, &matches_readback, 0, match_bytes);
    queue.submit(std::iter::once(encoder.finish()));

    // The reported count may exceed the buffer's capacity; only read records that fit.
    let reported = read_u32_buffer(device, &count_readback, 1)[0];
    let captured = reported.min(max_matches);
    assert!(
        captured >= 2,
        "expected at least 2 DFA matches, got {captured}"
    );
    let captured_words = usize::try_from(captured).expect("captured fits usize") * MATCH_FIELDS;
    let match_fields = read_u32_buffer(device, &matches_readback, captured_words);
    let matches = match_fields
        .chunks_exact(MATCH_FIELDS)
        .map(|fields| vyre::Match::new(fields[0], fields[1], fields[2]))
        .collect::<Vec<_>>();

    let found_pattern_ids: std::collections::BTreeSet<u32> =
        matches.iter().map(|m| m.pattern_id).collect();
    assert!(
        found_pattern_ids.contains(&0),
        "pattern 0 ('ab') not found in DFA matches: {matches:?}"
    );
    assert!(
        found_pattern_ids.contains(&1),
        "pattern 1 ('cd') not found in DFA matches: {matches:?}"
    );
}
/// Builds a three-buffer IR program computing `out[i] = a[i] + b[i]` for
/// `i = gid_x`, guarded by `i < buf_len(out)`, with a `[4, 1, 1]` workgroup.
///
/// When `use_call` is true, the sum is expressed as a call to
/// `primitive.math.add`; otherwise a direct `Expr::add` node is used.
fn add_program(use_call: bool) -> Program {
    let idx = Expr::var("idx");
    let lhs = Expr::load("a", idx.clone());
    let rhs = Expr::load("b", idx.clone());
    let sum = if !use_call {
        Expr::add(lhs, rhs)
    } else {
        Expr::call("primitive.math.add", vec![lhs, rhs])
    };

    let buffers = vec![
        BufferDecl::read("a", 0, DataType::U32),
        BufferDecl::read("b", 1, DataType::U32),
        BufferDecl::output("out", 2, DataType::U32),
    ];
    let in_bounds = Expr::lt(idx.clone(), Expr::buf_len("out"));
    let body = vec![
        Node::let_bind("idx", Expr::gid_x()),
        Node::if_then(in_bounds, vec![Node::store("out", idx, sum)]),
    ];
    Program::new(buffers, [4, 1, 1], body)
}
/// Validates and compiles `program`, binds `a`, `b` and a fresh output buffer at
/// bindings 0, 1 and 2, dispatches a single workgroup, and returns the output words.
///
/// # Panics
/// Panics if `a` and `b` differ in length, or if validation, compilation or readback fail.
fn dispatch_three_buffer_u32(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    program: &Program,
    a: &[u32],
    b: &[u32],
) -> Vec<u32> {
    assert_eq!(a.len(), b.len());
    assert_validation_passes(program);
    let pipeline = compile_program(device, program, "gpu_dispatch_add");

    let element_count = a.len();
    let output_bytes = byte_len(element_count);

    let lhs_buffer = storage_init(device, "gpu_dispatch_a", a);
    let rhs_buffer = storage_init(device, "gpu_dispatch_b", b);
    let out_buffer = storage_empty(
        device,
        "gpu_dispatch_out",
        element_count,
        wgpu::BufferUsages::COPY_SRC,
    );

    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("gpu_dispatch_add_bind_group"),
        layout: &layout,
        entries: &[
            vyre_wgpu::runtime::bg_entry(0, &lhs_buffer),
            vyre_wgpu::runtime::bg_entry(1, &rhs_buffer),
            vyre_wgpu::runtime::bg_entry(2, &out_buffer),
        ],
    });
    let staging = readback_buffer(device, "gpu_dispatch_add_readback", output_bytes);

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("gpu_dispatch_add_encoder"),
    });
    // Scope the pass so its borrow of `encoder` ends before the copy is recorded.
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("gpu_dispatch_add_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups(1, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&out_buffer, 0, &staging, 0, output_bytes);
    queue.submit([encoder.finish()]);

    read_u32_buffer(device, &staging, element_count)
}
fn cached_device() -> Option<&'static (wgpu::Device, wgpu::Queue)> {
match vyre_wgpu::runtime::cached_device() {
Ok(pair) => Some(pair),
Err(error) => {
panic!("GPU required on this machine (RTX 5090 / 4090 available per project invariant) — do not silently skip: gpu_dispatch integration test: {error}");
None
}
}
}
/// Panics with the full error list if `validate` reports any problem with `program`.
fn assert_validation_passes(program: &Program) {
    let errors = validate(program);
    if !errors.is_empty() {
        panic!("IR validation failed for GPU dispatch smoke test: {errors:?}");
    }
}
/// Lowers `program` to WGSL and compiles it into a compute pipeline whose entry point is `main`.
///
/// # Panics
/// Panics (mentioning `label`) if lowering or pipeline compilation fails.
fn compile_program(
    device: &wgpu::Device,
    program: &Program,
    label: &'static str,
) -> wgpu::ComputePipeline {
    let wgsl = match vyre::lower::wgsl::lower(program) {
        Ok(source) => source,
        Err(error) => panic!("WGSL lowering failed for {label}: {error}"),
    };
    match vyre_wgpu::runtime::compile_compute_pipeline(device, label, &wgsl, "main") {
        Ok(pipeline) => pipeline,
        Err(error) => panic!("WGSL compilation failed for {label}: {error}"),
    }
}
/// Creates a `STORAGE`-only buffer initialized with `data`.
fn storage_init(device: &wgpu::Device, label: &'static str, data: &[u32]) -> wgpu::Buffer {
    let storage_only = wgpu::BufferUsages::STORAGE;
    storage_init_with_usage(device, label, data, storage_only)
}
/// Creates a buffer with exactly `usage`, initialized with the raw bytes of `data`.
fn storage_init_with_usage(
    device: &wgpu::Device,
    label: &'static str,
    data: &[u32],
    usage: wgpu::BufferUsages,
) -> wgpu::Buffer {
    let contents: &[u8] = cast_slice(data);
    let descriptor = wgpu::util::BufferInitDescriptor {
        label: Some(label),
        contents,
        usage,
    };
    device.create_buffer_init(&descriptor)
}
/// Creates an unmapped storage buffer sized for `word_len` u32 words, with
/// `extra_usage` flags added to `STORAGE`.
fn storage_empty(
    device: &wgpu::Device,
    label: &'static str,
    word_len: usize,
    extra_usage: wgpu::BufferUsages,
) -> wgpu::Buffer {
    let size = byte_len(word_len);
    let usage = extra_usage | wgpu::BufferUsages::STORAGE;
    let descriptor = wgpu::BufferDescriptor {
        label: Some(label),
        size,
        usage,
        mapped_at_creation: false,
    };
    device.create_buffer(&descriptor)
}
/// Creates an unmapped `size`-byte buffer that can be a copy destination and mapped for reading.
fn readback_buffer(device: &wgpu::Device, label: &'static str, size: u64) -> wgpu::Buffer {
    let usage = wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST;
    let descriptor = wgpu::BufferDescriptor {
        label: Some(label),
        size,
        usage,
        mapped_at_creation: false,
    };
    device.create_buffer(&descriptor)
}
/// Maps the first `word_len` u32 words of a `MAP_READ` buffer, copies them out, and unmaps it.
///
/// Blocks on `device.poll(wgpu::Maintain::Wait)` until the map request has completed.
///
/// # Panics
/// Panics if the map callback never reports back or the mapping itself fails.
fn read_u32_buffer(device: &wgpu::Device, buffer: &wgpu::Buffer, word_len: usize) -> Vec<u32> {
    const WORD_BYTES: usize = std::mem::size_of::<u32>();

    let mapped_bytes = byte_len(word_len);
    let slice = buffer.slice(0..mapped_bytes);
    let (sender, receiver) = mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = sender.send(result);
    });
    match device.poll(wgpu::Maintain::Wait) {
        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
    }
    receiver
        .recv()
        .expect("readback map callback must report completion")
        .expect("readback buffer must map for reading");

    // Decode word-by-word rather than `cast_slice::<u8, u32>` on a copied `Vec<u8>`:
    // a `Vec<u8>` only guarantees 1-byte alignment, and `cast_slice` panics on misaligned
    // input. `from_ne_bytes` matches the native-endian reinterpretation `cast_slice` did.
    let words = {
        let mapped = slice.get_mapped_range();
        mapped
            .chunks_exact(WORD_BYTES)
            .map(|chunk| {
                u32::from_ne_bytes(
                    <[u8; WORD_BYTES]>::try_from(chunk).expect("chunks_exact yields 4-byte chunks"),
                )
            })
            .collect::<Vec<u32>>()
        // `mapped` is dropped here; the view must be released before `unmap`.
    };
    buffer.unmap();
    words
}
/// Converts a count of u32 words into a byte size suitable for wgpu buffer APIs.
///
/// # Panics
/// Panics if the byte count overflows `usize` or does not fit in `u64`.
fn byte_len(word_len: usize) -> u64 {
    const WORD_BYTES: usize = std::mem::size_of::<u32>();
    match word_len.checked_mul(WORD_BYTES) {
        Some(total) => u64::try_from(total).expect("u32 buffer byte length must fit u64"),
        None => panic!("u32 buffer byte length must not overflow usize"),
    }
}
/// Packs bytes into little-endian u32 words, four bytes per word, zero-padding the last word.
fn pack_input_words(input: &[u8]) -> Vec<u32> {
    let mut words = Vec::with_capacity(input.len() / 4 + 1);
    for group in input.chunks(4) {
        let mut padded = [0u8; 4];
        for (slot, &byte) in padded.iter_mut().zip(group) {
            *slot = byte;
        }
        words.push(u32::from_le_bytes(padded));
    }
    words
}