vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Comparator strategy implementations.

use super::{ApproximateMatch, Comparator, ExactMatch, SetEqualityMatch, UnorderedMatch};

impl Comparator for ExactMatch {
    /// Passes only when `gpu` and `cpu` contain exactly the same bytes.
    ///
    /// # Errors
    ///
    /// - Different lengths: a fixed "length mismatch" message.
    /// - Same length, different content: a message naming the start of the
    ///   64-byte region (see `aligned_offset`) that holds the first
    ///   differing byte.
    ///
    /// The precise lengths and byte values are only forwarded to
    /// `debug_diagnostic`; they are not part of the returned message.
    fn compare(&self, gpu: &[u8], cpu: &[u8]) -> Result<(), String> {
        if gpu == cpu {
            return Ok(());
        }

        if gpu.len() != cpu.len() {
            debug_diagnostic(format_args!(
                "exact length mismatch: gpu_len={}, cpu_len={}",
                gpu.len(),
                cpu.len()
            ));
            return Err(String::from(
                "length mismatch. Fix: return exactly the same byte count as the CPU reference.",
            ));
        }

        // Equal lengths but unequal slices: locate the first differing byte.
        let first_difference = gpu
            .iter()
            .zip(cpu.iter())
            .enumerate()
            .find(|(_, (g, c))| g != c);

        match first_difference {
            Some((idx, (g, c))) => {
                debug_diagnostic(format_args!(
                    "exact byte mismatch: index={idx}, gpu=0x{g:02x}, cpu=0x{c:02x}"
                ));
                Err(format!(
                    "mismatch at offset >= {}. Fix: inspect the corresponding WGSL output region.",
                    aligned_offset(idx)
                ))
            }
            // Not expected after the checks above; kept as the neutral result.
            None => Ok(()),
        }
    }
}

impl Comparator for UnorderedMatch {
    /// Compares the two buffers as unordered collections of fixed-width
    /// elements.
    ///
    /// The element width is not configured on this comparator: it is guessed
    /// from the GPU buffer length by `inferred_unordered_element_size`, and
    /// the comparison itself is delegated to `SetEqualityMatch` using that
    /// width.
    fn compare(&self, gpu: &[u8], cpu: &[u8]) -> Result<(), String> {
        let delegate = SetEqualityMatch {
            element_size: inferred_unordered_element_size(gpu.len()),
        };
        delegate.compare(gpu, cpu)
    }
}

impl Comparator for ApproximateMatch {
    /// Compares two equally long buffers element-wise within `self.epsilon`
    /// ULPs.
    ///
    /// The element width is chosen from the buffer length alone:
    /// - length divisible by 4: elements are decoded as little-endian `f32`;
    /// - length even but not divisible by 4: elements are decoded as 16-bit
    ///   floats;
    /// - odd length: rejected.
    ///
    /// # Errors
    ///
    /// Returns a message when the lengths differ, when the length is odd, or
    /// when any element pair is further apart than the tolerance.
    fn compare(&self, gpu: &[u8], cpu: &[u8]) -> Result<(), String> {
        let (gpu_len, cpu_len) = (gpu.len(), cpu.len());
        if gpu_len != cpu_len {
            debug_diagnostic(format_args!(
                "approximate length mismatch: gpu_len={}, cpu_len={}",
                gpu_len,
                cpu_len
            ));
            return Err(String::from(
                "approximate length mismatch. Fix: return the specified byte length.",
            ));
        }

        // `len % 4` is 0 for whole f32 words, 2 for an even length that is
        // not a multiple of four, and 1 or 3 for odd lengths.
        match gpu_len % 4 {
            0 => compare_f32_ulps(gpu, cpu, self.epsilon),
            2 => compare_u16_ulps(gpu, cpu, self.epsilon),
            _ => Err(String::from(
                "approximate output length is not divisible by 2 or 4. \
             Fix: declare F16, BF16, or F32 output bytes for ULP comparison.",
            )),
        }
    }
}

impl Comparator for SetEqualityMatch {
    /// Passes when both buffers contain the same `element_size`-byte elements,
    /// in any order. Duplicates count: both sides are sorted and then compared
    /// position by position.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// 1. the buffers differ in length;
    /// 2. `element_size` is 0 while the buffers are non-empty;
    /// 3. the length is not a multiple of `element_size`;
    /// 4. the sorted element sequences differ (the reported offset refers to
    ///    the sorted order, rounded down by `aligned_offset`).
    fn compare(&self, gpu: &[u8], cpu: &[u8]) -> Result<(), String> {
        /// Splits `bytes` into `width`-byte elements and sorts them
        /// lexicographically. `width` must be non-zero and divide the length.
        fn sorted_elements(bytes: &[u8], width: usize) -> Vec<&[u8]> {
            let mut elements: Vec<&[u8]> = bytes.chunks_exact(width).collect();
            elements.sort_unstable();
            elements
        }

        let width = self.element_size;

        if gpu.len() != cpu.len() {
            debug_diagnostic(format_args!(
                "set-equality length mismatch: gpu_len={}, cpu_len={}",
                gpu.len(),
                cpu.len()
            ));
            return Err(String::from(
                "set-equality length mismatch. Fix: return the same total byte count.",
            ));
        }

        if width == 0 {
            // A zero width is only meaningful for two empty buffers.
            if gpu.is_empty() && cpu.is_empty() {
                return Ok(());
            }
            return Err(String::from(
                "set-equality: element_size is 0 but output is non-empty. \
                     Fix: set element_size > 0.",
            ));
        }

        if gpu.len() % width != 0 {
            debug_diagnostic(format_args!(
                "set-equality divisibility mismatch: output_len={}, element_size={}",
                gpu.len(),
                width
            ));
            return Err(String::from(
                "set-equality: output length is not a multiple of element size. Fix: output must contain whole elements.",
            ));
        }

        let gpu_sorted = sorted_elements(gpu, width);
        let cpu_sorted = sorted_elements(cpu, width);

        let first_difference = gpu_sorted
            .iter()
            .zip(cpu_sorted.iter())
            .enumerate()
            .find(|(_, (g, c))| g != c);

        if let Some((idx, (g, c))) = first_difference {
            debug_diagnostic(format_args!(
                "set-equality element mismatch: sorted_index={idx}, gpu={g:02x?}, cpu={c:02x?}"
            ));
            return Err(format!(
                "set-equality: mismatch at offset >= {} after sorting. Fix: verify element content, not just count.",
                aligned_offset(idx.saturating_mul(width))
            ));
        }
        Ok(())
    }
}

/// Rounds `idx` down to the nearest multiple of 64.
///
/// Callers put this coarse offset into error messages instead of the exact
/// index of a mismatch.
fn aligned_offset(idx: usize) -> usize {
    const REGION: usize = 64;
    idx - idx % REGION
}

/// Emits `args` as a `tracing` debug event, but only when the environment
/// variable `VYRE_CONFORM_DEBUG_DIAGNOSTICS` is set to exactly `1`.
///
/// The variable is read again on every call.
fn debug_diagnostic(args: std::fmt::Arguments<'_>) {
    let flag = std::env::var_os("VYRE_CONFORM_DEBUG_DIAGNOSTICS");
    if flag.as_deref() != Some(std::ffi::OsStr::new("1")) {
        return;
    }
    tracing::debug!("{args}");
}

/// Guesses an element width in bytes from a buffer length.
///
/// Returns 8 when `len` is a non-zero multiple of 8, otherwise 4 when it is a
/// non-zero multiple of 4, otherwise 1 (which includes `len == 0`).
fn inferred_unordered_element_size(len: usize) -> usize {
    if len == 0 {
        return 1;
    }
    match (len % 8, len % 4) {
        (0, _) => 8,
        (_, 0) => 4,
        _ => 1,
    }
}

/// Compares two buffers as sequences of little-endian `f32` words.
///
/// Walks both buffers four bytes at a time and fails on the first pair whose
/// distance, as reported by `f32_ulp_distance`, is greater than `max_ulps`.
/// Trailing bytes that do not fill a whole word are not examined.
///
/// # Errors
///
/// Returns a message naming the element index, the measured distance and the
/// allowed maximum.
fn compare_f32_ulps(gpu: &[u8], cpu: &[u8], max_ulps: u32) -> Result<(), String> {
    let decode = |word: &[u8]| u32::from_le_bytes([word[0], word[1], word[2], word[3]]);
    let word_pairs = gpu
        .chunks_exact(4)
        .map(decode)
        .zip(cpu.chunks_exact(4).map(decode));

    for (index, (gpu_bits, cpu_bits)) in word_pairs.enumerate() {
        let distance = f32_ulp_distance(gpu_bits, cpu_bits);
        if distance <= max_ulps {
            continue;
        }
        return Err(format!(
            "approximate mismatch at f32 element {index}: ULP distance {distance} > max_ulps {max_ulps}. \
                 Fix: reduce numeric drift or update the declared tolerance."
        ));
    }
    Ok(())
}

/// Compares two buffers as sequences of little-endian 16-bit float words.
///
/// Walks both buffers two bytes at a time and fails on the first pair whose
/// distance, as reported by `f16_ulp_distance`, is greater than `max_ulps`.
/// A trailing odd byte is not examined.
///
/// # Errors
///
/// Returns a message naming the element index, the measured distance and the
/// allowed maximum.
fn compare_u16_ulps(gpu: &[u8], cpu: &[u8], max_ulps: u32) -> Result<(), String> {
    let decode = |half: &[u8]| u16::from_le_bytes([half[0], half[1]]);
    let half_pairs = gpu
        .chunks_exact(2)
        .map(decode)
        .zip(cpu.chunks_exact(2).map(decode));

    for (index, (gpu_bits, cpu_bits)) in half_pairs.enumerate() {
        let distance = f16_ulp_distance(gpu_bits, cpu_bits);
        if distance <= max_ulps {
            continue;
        }
        return Err(format!(
            "approximate mismatch at 16-bit float element {index}: ULP distance {distance} > max_ulps {max_ulps}. \
                 Fix: reduce numeric drift or update the declared tolerance."
        ));
    }
    Ok(())
}

/// Returns the distance, in units in the last place, between two `f32` values
/// given as raw bit patterns.
///
/// - Identical bit patterns, and values that compare equal (`+0.0` / `-0.0`),
///   are 0 apart.
/// - Two NaNs are 0 apart regardless of sign or payload.
/// - A NaN against any non-NaN value is `u32::MAX` apart. NaN bit patterns sit
///   directly next to infinity, so measuring them by bit distance would let a
///   NaN pass as "1 ULP from infinity"; `u32::MAX` exceeds every tolerance
///   below `u32::MAX`.
/// - Otherwise the result is the number of representable steps between the two
///   values, with both zeros counted as one point.
fn f32_ulp_distance(left: u32, right: u32) -> u32 {
    if left == right {
        return 0;
    }
    let left_value = f32::from_bits(left);
    let right_value = f32::from_bits(right);
    match (left_value.is_nan(), right_value.is_nan()) {
        (true, true) => 0,
        (true, false) | (false, true) => u32::MAX,
        // Covers +0.0 vs -0.0, whose bit patterns differ.
        (false, false) if left_value == right_value => 0,
        (false, false) => f32_ordered(left).abs_diff(f32_ordered(right)),
    }
}

/// Maps a sign-magnitude `f32` bit pattern onto an unsigned key whose integer
/// order matches the numeric order of the floats.
///
/// Non-negative values map to `0x8000_0000 + magnitude` and negative values to
/// `0x8000_0000 - magnitude`, so `+0.0` and `-0.0` share one key and adjacent
/// floats always have keys exactly one apart, including across zero.
fn f32_ordered(bits: u32) -> u32 {
    const SIGN: u32 = 0x8000_0000;
    let magnitude = bits & !SIGN;
    if bits & SIGN == 0 {
        SIGN + magnitude
    } else {
        SIGN - magnitude
    }
}

/// Returns the distance, in units in the last place, between two 16-bit float
/// values given as raw bit patterns.
///
/// - Identical bit patterns, and any pair of zeros (`+0` / `-0`), are 0 apart.
/// - Two NaNs are 0 apart regardless of sign or payload.
/// - A NaN against any non-NaN value is `u32::MAX` apart. NaN bit patterns sit
///   directly next to infinity, so measuring them by bit distance would let a
///   NaN pass as "1 ULP from infinity".
/// - Otherwise the result is the number of representable steps between the two
///   values, with both zeros counted as one point.
///
/// NOTE(review): NaN detection uses `u16_float_is_nan`, which assumes the IEEE
/// binary16 field layout. The `ApproximateMatch` error text also mentions
/// BF16, whose exponent field is wider — confirm whether BF16 buffers are
/// expected to reach this path.
fn f16_ulp_distance(left: u16, right: u16) -> u32 {
    if left == right || (left & 0x7fff == 0 && right & 0x7fff == 0) {
        return 0;
    }
    match (u16_float_is_nan(left), u16_float_is_nan(right)) {
        (true, true) => 0,
        (true, false) | (false, true) => u32::MAX,
        (false, false) => u32::from(f16_ordered(left).abs_diff(f16_ordered(right))),
    }
}

/// Reports whether `bits` is a NaN under the binary16 layout: all five
/// exponent bits (`0x7c00`) set and a non-zero 10-bit mantissa (`0x03ff`).
fn u16_float_is_nan(bits: u16) -> bool {
    bits & 0x7c00 == 0x7c00 && bits & 0x03ff != 0
}

/// Maps a sign-magnitude 16-bit float pattern onto an unsigned key whose
/// integer order matches the numeric order of the floats.
///
/// Non-negative values map to `0x8000 + magnitude` and negative values to
/// `0x8000 - magnitude`, so both zeros share one key and adjacent values
/// always have keys exactly one apart, including across zero.
fn f16_ordered(bits: u16) -> u16 {
    const SIGN: u16 = 0x8000;
    let magnitude = bits & !SIGN;
    if bits & SIGN == 0 {
        SIGN + magnitude
    } else {
        SIGN - magnitude
    }
}