#[cfg(test)]
mod tests {
use super::*;
use crate::pipeline::backend::{require_gpu, ConformDispatchConfig, WgslBackend};
use crate::wgpu_backend::WgpuBackend;
/// Workgroup sizes that `run_at_all_workgroups` dispatches every shader at.
const WORKGROUP_SWEEP: &[u32] = &[1, 8, 64, 256];
/// Dispatches `op_wgsl` once per entry of `WORKGROUP_SWEEP` and checks every
/// GPU result against `cpu_result`.
///
/// For each sweep size a clone of `base_config` gets its `workgroup_size`
/// overridden, the op source is wrapped with `wrap_shader`, and the wrapped
/// shader is dispatched with `input`, requesting `output_size` bytes back.
/// `context` prefixes every failure message.
///
/// # Panics
///
/// * if the backend reports a maximum workgroup invocation count smaller than
///   the sweep size being tried,
/// * if the dispatch returns an error,
/// * if the GPU output differs from `cpu_result`.
fn run_at_all_workgroups(
    backend: &WgpuBackend,
    op_wgsl: &str,
    input: &[u8],
    output_size: usize,
    base_config: &ConformDispatchConfig,
    cpu_result: &[u8],
    context: &str,
) {
    for workgroup_size in WORKGROUP_SWEEP.iter().copied() {
        // A backend that reports no limit (`None`) skips the capability check.
        if let Some(max) = backend.max_workgroup_invocations() {
            assert!(
                max >= workgroup_size,
                "{context}: GPU max workgroup invocations {max} cannot cover required sweep size {workgroup_size}. \
                 Fix: run on a GPU that supports the vyre conformance workgroup sweep."
            );
        }

        // Same configuration as the caller's, except for the workgroup size.
        let mut sized_config = base_config.clone();
        sized_config.workgroup_size = workgroup_size;
        let wrapped = crate::pipeline::backend::wrap_shader(op_wgsl, &sized_config);

        let gpu_bytes = match backend.dispatch(&wrapped, input, output_size, sized_config) {
            Ok(bytes) => bytes,
            Err(e) => {
                panic!("{context}: GPU dispatch failed at workgroup_size={workgroup_size}: {e}")
            }
        };

        assert_eq!(
            gpu_bytes, cpu_result,
            "{context}: GPU/CPU parity failure at workgroup_size={workgroup_size}: gpu={:?} cpu={:?}",
            gpu_bytes, cpu_result
        );
    }
}
/// Checks GPU/CPU parity of a two-operand op over `test_pairs`.
///
/// Each pair `(a, b)` is encoded as `a` then `b`, both little-endian `u32`
/// (8 bytes total). `cpu_fn` produces the reference bytes for that input, and
/// `run_at_all_workgroups` compares the GPU output (4 bytes requested) against
/// it. The whole set is repeated for every buffer-init policy supplied by
/// `run_with_all_init_policies`.
///
/// # Panics
///
/// Panics if `test_pairs` is empty, or if any delegated parity check fails.
fn verify_binary_parity(
    backend: &WgpuBackend,
    op_wgsl: &str,
    cpu_fn: fn(&[u8]) -> Vec<u8>,
    test_pairs: &[(u32, u32)],
    op_name: &str,
) {
    // Number of output bytes requested from every dispatch.
    const OUTPUT_BYTES: usize = 4;

    assert!(
        !test_pairs.is_empty(),
        "Fix: {op_name} must define at least one boundary pair"
    );

    crate::pipeline::backend::run_with_all_init_policies(|init| {
        let mut config = ConformDispatchConfig::default();
        config.buffer_init = init;

        for (a, b) in test_pairs.iter().copied() {
            let operands = [a.to_le_bytes(), b.to_le_bytes()].concat();
            let expected = cpu_fn(&operands);
            let label = format!("{op_name} ({a:#010X}, {b:#010X}) with {init:?}");
            run_at_all_workgroups(
                backend,
                op_wgsl,
                &operands,
                OUTPUT_BYTES,
                &config,
                &expected,
                &label,
            );
        }
    });
}
/// Checks GPU/CPU parity of a single-operand op over `test_values`.
///
/// Each value is encoded as a little-endian `u32` (4 bytes). `cpu_fn`
/// produces the reference bytes for that input, and `run_at_all_workgroups`
/// compares the GPU output (4 bytes requested) against it. The whole set is
/// repeated for every buffer-init policy supplied by
/// `run_with_all_init_policies`.
///
/// # Panics
///
/// Panics if `test_values` is empty, or if any delegated parity check fails.
fn verify_unary_parity(
    backend: &WgpuBackend,
    op_wgsl: &str,
    cpu_fn: fn(&[u8]) -> Vec<u8>,
    test_values: &[u32],
    op_name: &str,
) {
    // Number of output bytes requested from every dispatch.
    const OUTPUT_BYTES: usize = 4;

    assert!(
        !test_values.is_empty(),
        "Fix: {op_name} must define at least one boundary value"
    );

    crate::pipeline::backend::run_with_all_init_policies(|init| {
        let mut config = ConformDispatchConfig::default();
        config.buffer_init = init;

        for a in test_values.iter().copied() {
            let operand = a.to_le_bytes();
            let expected = cpu_fn(&operand);
            let label = format!("{op_name} {a:#010X} with {init:?}");
            run_at_all_workgroups(
                backend,
                op_wgsl,
                &operand,
                OUTPUT_BYTES,
                &config,
                &expected,
                &label,
            );
        }
    });
}
/// Operand pairs `(a, b)` used by the `binary_gpu_parity!` tests.
///
/// NOTE(review): the 31/32/33 entries look aimed at shift/rotate amounts
/// around the 32-bit word width — confirm against the op specs.
const BINARY_BOUNDARIES: &[(u32, u32)] = &[
// Combinations of zero and one.
(0, 0),
(0, 1),
(1, 0),
(1, 1),
// Combinations involving u32::MAX.
(0, u32::MAX),
(u32::MAX, 0),
(u32::MAX, u32::MAX),
(u32::MAX, 1),
(1, u32::MAX),
// Only the highest bit set.
(0x80000000, 0),
(0, 0x80000000),
(0x80000000, 0x80000000),
// Assorted bit patterns.
(0xDEADBEEF, 0xCAFEBABE),
(0x55555555, 0xAAAAAAAA),
(0xF0F0F0F0, 0x0F0F0F0F),
(0x12345678, 0x9ABCDEF0),
// 31, 32 and 33 in either operand position, paired with 1.
(31, 1),
(32, 1),
(33, 1),
(1, 31),
(1, 32),
(1, 33),
];
/// Operand values used by the `unary_gpu_parity!` tests.
const UNARY_BOUNDARIES: &[u32] = &[
// Small values.
0,
1,
2,
31,
32,
// Values around the 8-bit boundary, and bit 15 alone.
0xFF,
0x100,
0x8000,
// Highest bit only, every bit except the highest, and all bits set.
0x80000000,
0x7FFFFFFF,
u32::MAX,
// Assorted bit patterns.
0xDEADBEEF,
0x55555555,
0xAAAAAAAA,
0xF0F0F0F0,
];
/// Checks GPU/CPU parity of a two-operand op for every operand pair with both
/// values in `0..=255`.
///
/// Pairs are visited with `a` as the outer counter and `b` as the inner one.
/// Each pair is encoded as `a` then `b`, both little-endian `u32` (8 bytes
/// total). `cpu_fn` produces the reference bytes, and `run_at_all_workgroups`
/// compares the GPU output (4 bytes requested) against it. The whole sweep is
/// repeated for every buffer-init policy supplied by
/// `run_with_all_init_policies`.
///
/// # Panics
///
/// Panics if any delegated parity check fails.
fn verify_binary_exhaustive_u8(
    backend: &WgpuBackend,
    op_wgsl: &str,
    cpu_fn: fn(&[u8]) -> Vec<u8>,
    op_name: &str,
) {
    // Number of output bytes requested from every dispatch.
    const OUTPUT_BYTES: usize = 4;

    crate::pipeline::backend::run_with_all_init_policies(|init| {
        let mut config = ConformDispatchConfig::default();
        config.buffer_init = init;

        // Same visiting order as two nested loops: all `b` for a == 0, then a == 1, ...
        let all_pairs = (0u32..=255).flat_map(|a| (0u32..=255).map(move |b| (a, b)));
        for (a, b) in all_pairs {
            let operands = [a.to_le_bytes(), b.to_le_bytes()].concat();
            let expected = cpu_fn(&operands);
            let label = format!("{op_name} exhaustive ({a}, {b}) with {init:?}");
            run_at_all_workgroups(
                backend,
                op_wgsl,
                &operands,
                OUTPUT_BYTES,
                &config,
                &expected,
                &label,
            );
        }
    });
}
/// Checks GPU/CPU parity of a single-operand op for every value in `0..=255`.
///
/// Each value is encoded as a little-endian `u32` (4 bytes). `cpu_fn`
/// produces the reference bytes, and `run_at_all_workgroups` compares the GPU
/// output (4 bytes requested) against it. The whole sweep is repeated for
/// every buffer-init policy supplied by `run_with_all_init_policies`.
///
/// # Panics
///
/// Panics if any delegated parity check fails.
fn verify_unary_exhaustive_u8(
    backend: &WgpuBackend,
    op_wgsl: &str,
    cpu_fn: fn(&[u8]) -> Vec<u8>,
    op_name: &str,
) {
    // Number of output bytes requested from every dispatch.
    const OUTPUT_BYTES: usize = 4;

    crate::pipeline::backend::run_with_all_init_policies(|init| {
        let mut config = ConformDispatchConfig::default();
        config.buffer_init = init;

        for a in 0u32..=255 {
            let operand = a.to_le_bytes();
            let expected = cpu_fn(&operand);
            let label = format!("{op_name} exhaustive {a} with {init:?}");
            run_at_all_workgroups(
                backend,
                op_wgsl,
                &operand,
                OUTPUT_BYTES,
                &config,
                &expected,
                &label,
            );
        }
    });
}
use crate::spec::primitive;
/// Generates a `#[test]` named `$name` that runs `verify_binary_parity` over
/// `BINARY_BOUNDARIES` for the op described by `primitive::$mod::spec()`.
///
/// `$label` is the op name used in failure messages. The generated test
/// panics if `require_gpu()` does not yield a backend.
macro_rules! binary_gpu_parity {
    ($name:ident, $mod:ident, $label:expr) => {
        #[test]
        fn $name() {
            let gpu = require_gpu().expect("vyre-conform test needs a GPU adapter");
            let op = primitive::$mod::spec();
            let wgsl = (op.wgsl_fn)();
            verify_binary_parity(&gpu, &wgsl, op.cpu_fn, BINARY_BOUNDARIES, $label);
        }
    };
}
/// Generates a `#[test]` named `$name` that runs `verify_unary_parity` over
/// `UNARY_BOUNDARIES` for the op described by `primitive::$mod::spec()`.
///
/// `$label` is the op name used in failure messages. The generated test
/// panics if `require_gpu()` does not yield a backend.
macro_rules! unary_gpu_parity {
    ($name:ident, $mod:ident, $label:expr) => {
        #[test]
        fn $name() {
            let gpu = require_gpu().expect("vyre-conform test needs a GPU adapter");
            let op = primitive::$mod::spec();
            let wgsl = (op.wgsl_fn)();
            verify_unary_parity(&gpu, &wgsl, op.cpu_fn, UNARY_BOUNDARIES, $label);
        }
    };
}
/// Generates a `#[test]` named `$name` that runs `verify_binary_exhaustive_u8`
/// for the op described by `primitive::$mod::spec()`.
///
/// `$label` is the op name used in failure messages. The generated test
/// panics if `require_gpu()` does not yield a backend.
macro_rules! binary_exhaustive_u8 {
    ($name:ident, $mod:ident, $label:expr) => {
        #[test]
        fn $name() {
            let gpu = require_gpu().expect("vyre-conform test needs a GPU adapter");
            let op = primitive::$mod::spec();
            let wgsl = (op.wgsl_fn)();
            verify_binary_exhaustive_u8(&gpu, &wgsl, op.cpu_fn, $label);
        }
    };
}
/// Generates a `#[test]` named `$name` that runs `verify_unary_exhaustive_u8`
/// for the op described by `primitive::$mod::spec()`.
///
/// `$label` is the op name used in failure messages. The generated test
/// panics if `require_gpu()` does not yield a backend.
macro_rules! unary_exhaustive_u8 {
    ($name:ident, $mod:ident, $label:expr) => {
        #[test]
        fn $name() {
            let gpu = require_gpu().expect("vyre-conform test needs a GPU adapter");
            let op = primitive::$mod::spec();
            let wgsl = (op.wgsl_fn)();
            verify_unary_exhaustive_u8(&gpu, &wgsl, op.cpu_fn, $label);
        }
    };
}
// Boundary-value GPU/CPU parity tests: one generated `#[test]` per primitive.
// `binary_gpu_parity!` drives the op with BINARY_BOUNDARIES, `unary_gpu_parity!`
// with UNARY_BOUNDARIES.
// NOTE(review): clamp, select, extract_bits and insert_bits are generated with the
// binary macro, so they receive two-u32 (8-byte) inputs — confirm their specs
// expect that operand layout.
binary_gpu_parity!(gpu_parity_xor, xor, "xor");
binary_gpu_parity!(gpu_parity_and, and, "and");
binary_gpu_parity!(gpu_parity_or, or, "or");
unary_gpu_parity!(gpu_parity_not, not, "not");
binary_gpu_parity!(gpu_parity_shl, shl, "shl");
binary_gpu_parity!(gpu_parity_shr, shr, "shr");
binary_gpu_parity!(gpu_parity_rotl, rotl, "rotl");
binary_gpu_parity!(gpu_parity_rotr, rotr, "rotr");
unary_gpu_parity!(gpu_parity_popcount, popcount, "popcount");
unary_gpu_parity!(gpu_parity_clz, clz, "clz");
unary_gpu_parity!(gpu_parity_ctz, ctz, "ctz");
unary_gpu_parity!(gpu_parity_reverse_bits, reverse_bits, "reverse_bits");
binary_gpu_parity!(gpu_parity_extract_bits, extract_bits, "extract_bits");
binary_gpu_parity!(gpu_parity_insert_bits, insert_bits, "insert_bits");
binary_gpu_parity!(gpu_parity_add, add, "add");
binary_gpu_parity!(gpu_parity_sub, sub, "sub");
binary_gpu_parity!(gpu_parity_mul, mul, "mul");
binary_gpu_parity!(gpu_parity_div, div, "div");
// The module is `mod_op` while the label stays "mod".
binary_gpu_parity!(gpu_parity_mod, mod_op, "mod");
binary_gpu_parity!(gpu_parity_min, min, "min");
binary_gpu_parity!(gpu_parity_max, max, "max");
binary_gpu_parity!(gpu_parity_clamp, clamp, "clamp");
unary_gpu_parity!(gpu_parity_abs, abs, "abs");
unary_gpu_parity!(gpu_parity_negate, negate, "negate");
binary_gpu_parity!(gpu_parity_eq, eq, "eq");
binary_gpu_parity!(gpu_parity_ne, ne, "ne");
binary_gpu_parity!(gpu_parity_lt, lt, "lt");
binary_gpu_parity!(gpu_parity_gt, gt, "gt");
binary_gpu_parity!(gpu_parity_le, le, "le");
binary_gpu_parity!(gpu_parity_ge, ge, "ge");
binary_gpu_parity!(gpu_parity_select, select, "select");
unary_gpu_parity!(gpu_parity_logical_not, logical_not, "logical_not");
// Exhaustive tests over every operand value in 0..=255: one generated `#[test]`
// per primitive.
// NOTE(review): rotl, rotr, extract_bits, insert_bits, min, max, clamp and select
// have boundary parity tests but no exhaustive test here — confirm that is
// intentional.
binary_exhaustive_u8!(exhaustive_u8_xor, xor, "xor");
binary_exhaustive_u8!(exhaustive_u8_and, and, "and");
binary_exhaustive_u8!(exhaustive_u8_or, or, "or");
unary_exhaustive_u8!(exhaustive_u8_not, not, "not");
binary_exhaustive_u8!(exhaustive_u8_shl, shl, "shl");
binary_exhaustive_u8!(exhaustive_u8_shr, shr, "shr");
binary_exhaustive_u8!(exhaustive_u8_add, add, "add");
binary_exhaustive_u8!(exhaustive_u8_sub, sub, "sub");
binary_exhaustive_u8!(exhaustive_u8_mul, mul, "mul");
binary_exhaustive_u8!(exhaustive_u8_div, div, "div");
// The module is `mod_op` while the label stays "mod".
binary_exhaustive_u8!(exhaustive_u8_mod, mod_op, "mod");
binary_exhaustive_u8!(exhaustive_u8_eq, eq, "eq");
binary_exhaustive_u8!(exhaustive_u8_ne, ne, "ne");
binary_exhaustive_u8!(exhaustive_u8_lt, lt, "lt");
binary_exhaustive_u8!(exhaustive_u8_gt, gt, "gt");
binary_exhaustive_u8!(exhaustive_u8_le, le, "le");
binary_exhaustive_u8!(exhaustive_u8_ge, ge, "ge");
unary_exhaustive_u8!(exhaustive_u8_popcount, popcount, "popcount");
unary_exhaustive_u8!(exhaustive_u8_clz, clz, "clz");
unary_exhaustive_u8!(exhaustive_u8_ctz, ctz, "ctz");
unary_exhaustive_u8!(exhaustive_u8_reverse_bits, reverse_bits, "reverse_bits");
unary_exhaustive_u8!(exhaustive_u8_negate, negate, "negate");
unary_exhaustive_u8!(exhaustive_u8_abs, abs, "abs");
unary_exhaustive_u8!(exhaustive_u8_logical_not, logical_not, "logical_not");
}