use super::{ApproximateMatch, Comparator, ExactMatch, SetEqualityMatch, UnorderedMatch};
/// Byte-for-byte comparison between the GPU output and the CPU reference.
impl Comparator for ExactMatch {
    /// Succeeds only when `gpu` and `cpu` are identical in length and content.
    ///
    /// # Errors
    ///
    /// Returns a length-mismatch message when the buffers differ in size. Otherwise
    /// returns a message naming the first differing byte, with its offset rounded
    /// down to a multiple of 64 by `aligned_offset`. The exact index and byte values
    /// are only sent to `debug_diagnostic`.
    fn compare(&self, gpu: &[u8], cpu: &[u8]) -> Result<(), String> {
        // Fast path: identical buffers need no diagnostics.
        if gpu == cpu {
            return Ok(());
        }

        let (gpu_len, cpu_len) = (gpu.len(), cpu.len());
        if gpu_len != cpu_len {
            debug_diagnostic(format_args!(
                "exact length mismatch: gpu_len={}, cpu_len={}",
                gpu_len, cpu_len
            ));
            return Err(
                "length mismatch. Fix: return exactly the same byte count as the CPU reference."
                    .to_string(),
            );
        }

        let first_difference = gpu
            .iter()
            .zip(cpu)
            .enumerate()
            .find(|(_, (g, c))| g != c);

        match first_difference {
            Some((idx, (g, c))) => {
                debug_diagnostic(format_args!(
                    "exact byte mismatch: index={idx}, gpu=0x{g:02x}, cpu=0x{c:02x}"
                ));
                Err(format!(
                    "mismatch at offset >= {}. Fix: inspect the corresponding WGSL output region.",
                    aligned_offset(idx)
                ))
            }
            // Equal lengths but `gpu != cpu` guarantees a difference above; kept for totality.
            None => Ok(()),
        }
    }
}
/// Order-insensitive comparison that delegates to [`SetEqualityMatch`].
impl Comparator for UnorderedMatch {
    /// Compares `gpu` and `cpu` as multisets of fixed-width elements.
    ///
    /// The element width is guessed from `gpu.len()` by `inferred_unordered_element_size`:
    /// 8 if the length is a non-zero multiple of 8, else 4 if it is a non-zero multiple
    /// of 4, else 1.
    /// NOTE(review): with this guess, 4-byte elements whose count is even are grouped
    /// in pairs. Permutations that cross pair boundaries are then reported as
    /// mismatches. Confirm this is the intended element width for callers.
    fn compare(&self, gpu: &[u8], cpu: &[u8]) -> Result<(), String> {
        let matcher = SetEqualityMatch {
            element_size: inferred_unordered_element_size(gpu.len()),
        };
        matcher.compare(gpu, cpu)
    }
}
/// Floating-point comparison that tolerates up to `epsilon` ULPs per element.
impl Comparator for ApproximateMatch {
    /// Compares `gpu` against `cpu` element by element, using ULP distance.
    ///
    /// Lengths must match exactly. The element width is inferred from the byte
    /// length alone:
    /// - a length divisible by 4 is compared as little-endian `f32` lanes;
    /// - otherwise an even length is compared as 16-bit float lanes;
    /// - an odd length is rejected.
    ///
    /// NOTE(review): 16-bit output whose byte length is divisible by 4 is therefore
    /// compared as `f32`. Confirm that callers never rely on 16-bit comparison for
    /// such lengths.
    ///
    /// # Errors
    ///
    /// Returns an error on a length mismatch, on an odd length, or when any element
    /// exceeds the ULP tolerance.
    fn compare(&self, gpu: &[u8], cpu: &[u8]) -> Result<(), String> {
        let gpu_len = gpu.len();
        let cpu_len = cpu.len();
        if gpu_len != cpu_len {
            debug_diagnostic(format_args!(
                "approximate length mismatch: gpu_len={}, cpu_len={}",
                gpu_len, cpu_len
            ));
            return Err(
                "approximate length mismatch. Fix: return the specified byte length.".to_string(),
            );
        }

        // An odd length can be neither 2- nor 4-byte lanes.
        if gpu_len % 2 != 0 {
            return Err("approximate output length is not divisible by 2 or 4. \
                Fix: declare F16, BF16, or F32 output bytes for ULP comparison."
                .to_string());
        }

        if gpu_len % 4 == 0 {
            compare_f32_ulps(gpu, cpu, self.epsilon)
        } else {
            compare_u16_ulps(gpu, cpu, self.epsilon)
        }
    }
}
/// Multiset comparison over fixed-width elements, ignoring element order.
impl Comparator for SetEqualityMatch {
    /// Succeeds when `gpu` and `cpu` contain the same `element_size`-byte elements
    /// with the same multiplicities, in any order.
    ///
    /// Both sides are split into elements and sorted by byte content, then compared
    /// position by position.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the total lengths differ;
    /// - `element_size` is 0 while output is non-empty;
    /// - the length is not a whole number of elements;
    /// - the sorted element lists differ. The reported offset refers to the sorted
    ///   order, rounded down to a multiple of 64 by `aligned_offset`.
    fn compare(&self, gpu: &[u8], cpu: &[u8]) -> Result<(), String> {
        /// Splits `bytes` into `width`-byte elements and sorts them by content.
        fn sorted_chunks(bytes: &[u8], width: usize) -> Vec<&[u8]> {
            let mut elements: Vec<&[u8]> = bytes.chunks_exact(width).collect();
            elements.sort_unstable();
            elements
        }

        let width = self.element_size;

        if gpu.len() != cpu.len() {
            debug_diagnostic(format_args!(
                "set-equality length mismatch: gpu_len={}, cpu_len={}",
                gpu.len(),
                cpu.len()
            ));
            return Err(
                "set-equality length mismatch. Fix: return the same total byte count.".to_string(),
            );
        }

        // A zero width cannot be chunked; only the trivial empty case can succeed.
        if width == 0 {
            if gpu.is_empty() && cpu.is_empty() {
                return Ok(());
            }
            return Err("set-equality: element_size is 0 but output is non-empty. \
                Fix: set element_size > 0."
                .to_string());
        }

        if gpu.len() % width != 0 {
            debug_diagnostic(format_args!(
                "set-equality divisibility mismatch: output_len={}, element_size={}",
                gpu.len(),
                width
            ));
            return Err(
                "set-equality: output length is not a multiple of element size. Fix: output must contain whole elements."
                    .to_string(),
            );
        }

        // Divisibility was checked above, so `chunks_exact` leaves no remainder.
        let gpu_sorted = sorted_chunks(gpu, width);
        let cpu_sorted = sorted_chunks(cpu, width);

        let mismatch = gpu_sorted
            .iter()
            .zip(&cpu_sorted)
            .enumerate()
            .find(|(_, (g, c))| g != c);

        if let Some((idx, (g, c))) = mismatch {
            debug_diagnostic(format_args!(
                "set-equality element mismatch: sorted_index={idx}, gpu={g:02x?}, cpu={c:02x?}"
            ));
            return Err(format!(
                "set-equality: mismatch at offset >= {} after sorting. Fix: verify element content, not just count.",
                aligned_offset(idx.saturating_mul(width))
            ));
        }

        Ok(())
    }
}
/// Rounds `idx` down to the nearest multiple of 64.
///
/// Mismatch messages report this coarse offset rather than the exact index.
fn aligned_offset(idx: usize) -> usize {
    const ALIGNMENT: usize = 64;
    idx - idx % ALIGNMENT
}
/// Emits `args` at `tracing` debug level, but only when the environment variable
/// `VYRE_CONFORM_DEBUG_DIAGNOSTICS` is set to exactly `1`.
///
/// The variable is re-read on every call.
fn debug_diagnostic(args: std::fmt::Arguments<'_>) {
    let flag = std::env::var_os("VYRE_CONFORM_DEBUG_DIAGNOSTICS");
    let enabled = matches!(flag.as_deref(), Some(value) if value == std::ffi::OsStr::new("1"));
    if !enabled {
        return;
    }
    tracing::debug!("{args}");
}
/// Guesses an element width for unordered comparison from a buffer length.
///
/// Returns 8 for non-zero multiples of 8, 4 for other non-zero multiples of 4,
/// and 1 otherwise, including for an empty buffer.
fn inferred_unordered_element_size(len: usize) -> usize {
    match len {
        0 => 1,
        n if n % 8 == 0 => 8,
        n if n % 4 == 0 => 4,
        _ => 1,
    }
}
/// Compares two byte buffers as little-endian `f32` lanes, failing on the first
/// lane whose ULP distance exceeds `max_ulps`.
///
/// Assumes the caller already checked that both buffers have the same length.
/// Trailing bytes that do not form a full 4-byte lane are ignored.
fn compare_f32_ulps(gpu: &[u8], cpu: &[u8], max_ulps: u32) -> Result<(), String> {
    let to_bits = |chunk: &[u8]| -> u32 {
        u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])
    };

    // Lazy pipeline: stops at the first out-of-tolerance lane.
    let first_failure = gpu
        .chunks_exact(4)
        .map(to_bits)
        .zip(cpu.chunks_exact(4).map(to_bits))
        .map(|(gpu_bits, cpu_bits)| f32_ulp_distance(gpu_bits, cpu_bits))
        .enumerate()
        .find(|&(_, distance)| distance > max_ulps);

    match first_failure {
        Some((index, distance)) => Err(format!(
            "approximate mismatch at f32 element {index}: ULP distance {distance} > max_ulps {max_ulps}. \
             Fix: reduce numeric drift or update the declared tolerance."
        )),
        None => Ok(()),
    }
}
/// Compares two byte buffers as little-endian 16-bit float lanes, failing on the
/// first lane whose ULP distance (per `f16_ulp_distance`) exceeds `max_ulps`.
///
/// Assumes the caller already checked that both buffers have the same length.
/// A trailing odd byte is ignored.
fn compare_u16_ulps(gpu: &[u8], cpu: &[u8], max_ulps: u32) -> Result<(), String> {
    let to_bits = |chunk: &[u8]| -> u16 { u16::from_le_bytes([chunk[0], chunk[1]]) };

    // Lazy pipeline: stops at the first out-of-tolerance lane.
    let first_failure = gpu
        .chunks_exact(2)
        .map(to_bits)
        .zip(cpu.chunks_exact(2).map(to_bits))
        .map(|(gpu_bits, cpu_bits)| f16_ulp_distance(gpu_bits, cpu_bits))
        .enumerate()
        .find(|&(_, distance)| distance > max_ulps);

    match first_failure {
        Some((index, distance)) => Err(format!(
            "approximate mismatch at 16-bit float element {index}: ULP distance {distance} > max_ulps {max_ulps}. \
             Fix: reduce numeric drift or update the declared tolerance."
        )),
        None => Ok(()),
    }
}
/// Returns the distance, in units in the last place, between two IEEE-754
/// binary32 bit patterns.
///
/// - Bitwise-identical or numerically equal values (including `+0.0` vs `-0.0`)
///   are 0 apart.
/// - Two NaNs are treated as matching (distance 0).
/// - A NaN against any non-NaN returns `u32::MAX`. Bit-adjacency between NaN
///   payloads and infinity carries no numeric meaning, so such a pair must never
///   fall inside a tolerance.
/// - Otherwise the distance is the gap between the values' positions in numeric
///   order. Adjacent floats are 1 apart, and so are infinity and `f32::MAX`.
fn f32_ulp_distance(left: u32, right: u32) -> u32 {
    let left_value = f32::from_bits(left);
    let right_value = f32::from_bits(right);
    if left == right || left_value == right_value {
        return 0;
    }
    match (left_value.is_nan(), right_value.is_nan()) {
        (true, true) => 0,
        (true, false) | (false, true) => u32::MAX,
        (false, false) => f32_ordered(left).abs_diff(f32_ordered(right)),
    }
}

/// Maps a binary32 bit pattern to a `u32` key whose unsigned order matches the
/// numeric order of the encoded (non-NaN) floats.
///
/// Positive values map to `0x8000_0000 + bits`. Negative values map to
/// `0x8000_0000 - magnitude`. Both zeros therefore share the key `0x8000_0000`, and
/// a pair straddling zero is not counted one ULP too far apart.
/// The subtraction cannot underflow because the magnitude is at most `0x7FFF_FFFF`.
fn f32_ordered(bits: u32) -> u32 {
    const SIGN: u32 = 0x8000_0000;
    if bits & SIGN == 0 {
        bits | SIGN
    } else {
        SIGN - (bits & !SIGN)
    }
}
/// Returns the distance, in units in the last place, between two 16-bit float
/// bit patterns.
///
/// - Identical bits, or any two zeros (`+0` / `-0`), are 0 apart.
/// - Two NaNs (per `u16_float_is_nan`) are treated as matching.
/// - A NaN against a non-NaN returns `u32::MAX`, so a NaN that happens to be
///   bit-adjacent to infinity can never pass a tolerance.
/// - Otherwise the distance is the gap between the values' positions in
///   sign-magnitude order.
fn f16_ulp_distance(left: u16, right: u16) -> u32 {
    if left == right || (left & 0x7fff == 0 && right & 0x7fff == 0) {
        return 0;
    }
    match (u16_float_is_nan(left), u16_float_is_nan(right)) {
        (true, true) => 0,
        (true, false) | (false, true) => u32::MAX,
        (false, false) => u32::from(f16_ordered(left).abs_diff(f16_ordered(right))),
    }
}

/// Returns true when `bits` encodes a NaN in the IEEE-754 binary16 (F16) layout:
/// all five exponent bits set and a non-zero mantissa.
///
/// NOTE(review): callers also route BF16 data here. BF16 NaNs are always F16 NaNs,
/// but large finite BF16 values whose top five exponent bits are all set are also
/// classified as NaN by this test. Confirm that this is acceptable for BF16 outputs.
fn u16_float_is_nan(bits: u16) -> bool {
    bits & 0x7c00 == 0x7c00 && bits & 0x03ff != 0
}

/// Maps a sign-magnitude 16-bit float pattern to a `u16` key whose unsigned order
/// matches numeric order.
///
/// Positive values map to `0x8000 + bits`. Negative values map to
/// `0x8000 - magnitude`. Both zeros share the key `0x8000`, so a pair straddling
/// zero is not counted one ULP too far apart.
fn f16_ordered(bits: u16) -> u16 {
    const SIGN: u16 = 0x8000;
    if bits & SIGN == 0 {
        bits | SIGN
    } else {
        SIGN - (bits & !SIGN)
    }
}