vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
use crate::spec::types::ParityFailure;
use std::fs;
use std::io::{self, Write};
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};

/// Process-wide counter that `temp_path` appends to every temporary file name,
/// so successive calls within one process get distinct names even when the
/// pid and the clock reading are the same.
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);

// Only compiled when the crate is built with `--cfg loom`.
#[cfg(loom)]
use loom::sync::Mutex as LoomMutex;

/// In-memory replacement for the on-disk regression store under `cfg(loom)`:
/// `save` pushes failures into it and `load_all_versions` reads them back.
#[cfg(loom)]
static LOOM_STORE: LoomMutex<Vec<ParityFailure>> = LoomMutex::new(Vec::new());

/// Load persisted failing inputs for an operation (all versions).
///
/// Thin alias for [`load_all_versions`]: the result is exactly what that
/// function returns for `op_id`, i.e. `(label, input bytes)` pairs sorted by
/// label.
#[inline]
pub fn load(op_id: &str) -> Vec<(String, Vec<u8>)> {
    load_all_versions(op_id)
}

/// Read every regression file directly inside `dir` and label it with `tag`.
///
/// Three file kinds are recognised by extension:
/// * `.json` — a serialized failure; only its `input` bytes are kept,
/// * `.hex`  — hex text (surrounding whitespace trimmed) decoded to bytes,
/// * `.bin`  — raw bytes used as-is.
///
/// Anything else, and any file that cannot be read or decoded, is skipped
/// without an error. Each result is labelled `regression:<tag>:<file stem>`
/// (falling back to `regression` when the stem is not valid UTF-8) and the
/// list is returned sorted by label. A directory that cannot be listed yields
/// an empty list.
fn load_from_dir(dir: &std::path::Path, tag: &str) -> Vec<(String, Vec<u8>)> {
    let entries = match fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return Vec::new(),
    };
    let mut cases: Vec<(String, Vec<u8>)> = entries
        .flatten()
        .filter_map(|entry| {
            let file = entry.path();
            // Decode the payload first; unreadable or malformed files drop out here.
            let payload = match file.extension().and_then(|ext| ext.to_str()) {
                Some("json") => fs::read_to_string(&file)
                    .ok()
                    .and_then(|text| serde_json::from_str::<PersistedFailure>(&text).ok())
                    .map(|persisted| persisted.input),
                Some("hex") => fs::read_to_string(&file)
                    .ok()
                    .and_then(|text| decode_hex(text.trim()).ok()),
                Some("bin") => fs::read(&file).ok(),
                _ => None,
            }?;
            let stem = file
                .file_stem()
                .and_then(|stem| stem.to_str())
                .unwrap_or("regression");
            Some((format!("regression:{tag}:{stem}"), payload))
        })
        .collect();
    cases.sort_by(|left, right| left.0.cmp(&right.0));
    cases
}

/// Persist a failing input so future runs replay it before generated cases.
///
/// The failure is serialized as pretty-printed JSON and written to
/// `regressions/<op_id>/v<spec_version>/<sha256_hex(input)>.json`; the
/// directory is created on demand and the path of the file is returned.
///
/// Under `cfg(loom)` nothing touches the file system: the failure is cloned
/// into the in-memory store and the placeholder path `loom-mem` is returned.
///
/// # Errors
///
/// Propagates any error from creating the directory, serializing the failure,
/// or writing the file — including `AlreadyExists` when a file of that name is
/// already present with different bytes.
#[inline]
pub fn save(failure: &ParityFailure) -> io::Result<PathBuf> {
    #[cfg(loom)]
    {
        LOOM_STORE.lock().unwrap().push(failure.clone());
        return Ok(PathBuf::from("loom-mem"));
    }
    #[cfg(not(loom))]
    {
        let version_dir = versioned_regression_dir(&failure.op_id, failure.spec_version);
        fs::create_dir_all(&version_dir)?;
        let payload = serialize_failure(failure)?;
        let target = version_dir.join(format!("{}.json", sha256_hex(&failure.input)));
        atomic_write_new(&target, &payload)?;
        Ok(target)
    }
}

/// Persist a failing streaming input as raw bytes.
///
/// The input is written verbatim to `regressions/<op_id>/<hash>.bin`, where
/// `<hash>` is `sha256_hex` of the input, so large inputs avoid hex expansion
/// while the file name stays deterministic. The directory is created on
/// demand and the path of the written file is returned.
///
/// # Errors
///
/// Propagates any error from creating the directory or writing the file.
#[inline]
pub fn save_binary(failure: &ParityFailure) -> io::Result<PathBuf> {
    let op_dir = regression_dir(&failure.op_id);
    fs::create_dir_all(&op_dir)?;
    let target = op_dir.join(format!("{}.bin", sha256_hex(&failure.input)));
    atomic_write_new(&target, &failure.input).map(|()| target)
}

/// Load full persisted parity failures for a specific operation and version.
#[inline]
pub fn load_failures_versioned(op_id: &str, version: u32) -> Vec<ParityFailure> {
    let dir = versioned_regression_dir(op_id, version);
    let Ok(entries) = fs::read_dir(dir) else {
        return Vec::new();
    };
    let mut failures = Vec::new();
    for entry in entries.flatten() {
        let path = entry.path();
        if path.extension().and_then(|ext| ext.to_str()) != Some("json") {
            continue;
        }
        if let Ok(text) = fs::read_to_string(path) {
            if let Ok(persisted) = serde_json::from_str::<PersistedFailure>(&text) {
                failures.push(persisted.into_failure());
            }
        }
    }
    failures.sort_by(|a, b| a.input_label.cmp(&b.input_label));
    failures
}

/// Load regressions for a specific version of an operation.
///
/// Every returned label carries the version tag, i.e. it has the form
/// `regression:v<version>:<file stem>`.
#[inline]
pub fn load_versioned(op_id: &str, version: u32) -> Vec<(String, Vec<u8>)> {
    let tag = format!("v{version}");
    load_from_dir(&versioned_regression_dir(op_id, version), &tag)
}

/// Load regressions from all versions of an operation, including
/// legacy unversioned regressions and all versioned subdirectories.
///
/// This is the migration path: when version increments, regressions from
/// prior versions are still loaded and re-tested (they should still pass
/// under the new version if the change was backwards-compatible).
#[inline]
pub fn load_all_versions(op_id: &str) -> Vec<(String, Vec<u8>)> {
    #[cfg(loom)]
    {
        let store = LOOM_STORE.lock().unwrap();
        let mut results: Vec<_> = store
            .iter()
            .filter(|f| f.op_id == op_id)
            .map(|f| (f.input_label.clone(), f.input.clone()))
            .collect();
        results.sort_by(|a, b| a.0.cmp(&b.0));
        results
    }
    #[cfg(not(loom))]
    {
        let mut results = Vec::new();
        // Legacy unversioned regressions (before version migration was added).
        let legacy_dir = regression_dir(op_id);
        results.extend(load_from_dir(&legacy_dir, "legacy"));
        // Versioned regressions.
        if let Ok(entries) = fs::read_dir(&legacy_dir) {
            for entry in entries.flatten() {
                let path = entry.path();
                if path.is_dir() {
                    if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
                        if name.starts_with('v') {
                            results.extend(load_from_dir(&path, name));
                        }
                    }
                }
            }
        }
        results.sort_by(|a, b| a.0.cmp(&b.0));
        results
    }
}

/// Fuzz entry point: load regressions from an arbitrary directory tree.
#[inline]
pub fn load_all_versions_from_dir(dir: &std::path::Path) -> Vec<(String, Vec<u8>)> {
    let mut results = Vec::new();
    results.extend(load_from_dir(dir, "legacy"));
    if let Ok(entries) = fs::read_dir(dir) {
        for entry in entries.flatten() {
            let path = entry.path();
            if path.is_dir() {
                if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
                    if name.starts_with('v') {
                        results.extend(load_from_dir(&path, name));
                    }
                }
            }
        }
    }
    results.sort_by(|a, b| a.0.cmp(&b.0));
    results
}

/// Fuzz entry point: load versioned failures from an arbitrary directory.
#[inline]
pub fn load_failures_from_dir(dir: &std::path::Path) -> Vec<ParityFailure> {
    let Ok(entries) = fs::read_dir(dir) else {
        return Vec::new();
    };
    let mut failures = Vec::new();
    for entry in entries.flatten() {
        let path = entry.path();
        if path.extension().and_then(|ext| ext.to_str()) != Some("json") {
            continue;
        }
        if let Ok(text) = fs::read_to_string(path) {
            if let Ok(persisted) = serde_json::from_str::<PersistedFailure>(&text) {
                failures.push(persisted.into_failure());
            }
        }
    }
    failures.sort_by(|a, b| a.input_label.cmp(&b.input_label));
    failures
}

/// Directory holding every regression recorded for `op_id`:
/// `<CARGO_MANIFEST_DIR>/regressions/<sanitized op_id>`.
///
/// The manifest directory is captured at compile time via `env!`, and the
/// operation id goes through [`sanitize`] before it becomes a path component.
pub(super) fn regression_dir(op_id: &str) -> PathBuf {
    let mut dir = PathBuf::from(std::env!("CARGO_MANIFEST_DIR"));
    dir.push("regressions");
    dir.push(sanitize(op_id));
    dir
}

/// Directory for one spec version of an operation: the operation's regression
/// directory with a `v<version>` component appended.
fn versioned_regression_dir(op_id: &str, version: u32) -> PathBuf {
    let mut dir = regression_dir(op_id);
    dir.push(format!("v{version}"));
    dir
}

/// Percent-encode `value` so it can be used as a single path component.
///
/// ASCII letters, ASCII digits, `-` and `_` pass through unchanged. Every
/// other byte of the UTF-8 encoding (including `.`, `/` and `%` itself) is
/// written as `%` followed by two upper-case hex digits.
pub(super) fn sanitize(value: &str) -> String {
    let mut encoded = String::with_capacity(value.len());
    for byte in value.bytes() {
        match byte {
            b'-' | b'_' => encoded.push(char::from(byte)),
            _ if byte.is_ascii_alphanumeric() => encoded.push(char::from(byte)),
            _ => {
                let high = nibble(byte >> 4).to_ascii_uppercase();
                let low = nibble(byte & 0x0F).to_ascii_uppercase();
                encoded.extend(['%', high, low]);
            }
        }
    }
    encoded
}

/// Render `bytes` as lower-case hex, two characters per byte, high nibble
/// first. Only compiled for tests.
#[cfg(test)]
pub(super) fn encode_hex(bytes: &[u8]) -> String {
    bytes
        .iter()
        .flat_map(|byte| [nibble(byte >> 4), nibble(byte & 0x0F)])
        .collect()
}

/// Decode a string of hex digit pairs into bytes.
///
/// Each pair is read high nibble first; upper- and lower-case digits are both
/// accepted. The empty string decodes to an empty vector.
///
/// # Errors
///
/// Returns a message when the input has an odd number of bytes, or the
/// message produced by [`from_hex`] for the first byte that is not a hex
/// digit.
pub(super) fn decode_hex(text: &str) -> Result<Vec<u8>, String> {
    let digits = text.as_bytes();
    if digits.len() % 2 == 1 {
        return Err("hex input has odd length. Fix: use two hex chars per byte.".to_string());
    }
    digits
        .chunks_exact(2)
        .map(|pair| -> Result<u8, String> {
            let high = from_hex(pair[0])?;
            let low = from_hex(pair[1])?;
            Ok((high << 4) | low)
        })
        .collect()
}

/// Lower-case hex digit for the low four bits of `value`.
///
/// Only bits 0-3 are consulted, so the function is total: values above 15 no
/// longer hit an out-of-bounds index panic, they simply yield the digit for
/// `value & 0x0F`. For arguments in `0..=15` the result is unchanged.
pub(super) fn nibble(value: u8) -> char {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    // The mask keeps the index within the 16-entry table for every `u8`.
    char::from(DIGITS[usize::from(value & 0x0F)])
}

/// Numeric value (0-15) of a single ASCII hex digit.
///
/// Accepts `0-9`, `a-f` and `A-F`.
///
/// # Errors
///
/// Any other byte yields a fixed error message naming the accepted characters.
pub(super) fn from_hex(value: u8) -> Result<u8, String> {
    char::from(value)
        .to_digit(16)
        // A base-16 digit is at most 15, so the narrowing cast is lossless.
        .map(|digit| digit as u8)
        .ok_or_else(|| "invalid hex byte. Fix: use characters 0-9, a-f, or A-F.".to_string())
}

/// Encode `failure` as pretty-printed JSON bytes.
///
/// # Errors
///
/// A `serde_json` failure is wrapped in an `io::Error` of kind `InvalidData`
/// whose message includes the underlying error.
fn serialize_failure(failure: &ParityFailure) -> io::Result<Vec<u8>> {
    match serde_json::to_vec_pretty(failure) {
        Ok(encoded) => Ok(encoded),
        Err(err) => {
            let message = format!("could not serialize regression failure: {err}. Fix: persist JSON-compatible ParityFailure fields.");
            Err(io::Error::new(io::ErrorKind::InvalidData, message))
        }
    }
}

/// Write `bytes` to `path` through a temporary sibling file.
///
/// The temporary file is created with `create_new`, so an existing file of
/// that name is never reused. Writing, syncing and publishing are delegated
/// to `write_and_commit`; if that fails, removal of the temporary file is
/// attempted (its own outcome ignored) and the original error is returned.
fn atomic_write_new(path: &std::path::Path, bytes: &[u8]) -> io::Result<()> {
    let staging = temp_path(path);
    let mut handle = fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(&staging)?;
    match write_and_commit(&mut handle, &staging, path, bytes) {
        Ok(()) => Ok(()),
        Err(err) => {
            // Best-effort cleanup; the commit error is the one worth reporting.
            let _ = fs::remove_file(&staging);
            Err(err)
        }
    }
}

/// Fill the temporary file and publish it at `path` without overwriting.
///
/// `bytes` are written to `tmp` and synced, then `tmp_path` is hard-linked to
/// `path`; a hard link fails if the destination already exists, so an
/// existing file is never replaced.
///
/// * Link succeeds: the temporary name is removed and that result is returned.
/// * Destination already exists: it is read back, the temporary file is
///   removed, and the call succeeds only if the existing content equals
///   `bytes`; otherwise an `AlreadyExists` error naming the path is returned.
/// * Any other link error is returned as-is (the temporary file is left for
///   the caller to clean up).
fn write_and_commit(
    tmp: &mut fs::File,
    tmp_path: &std::path::Path,
    path: &std::path::Path,
    bytes: &[u8],
) -> io::Result<()> {
    tmp.write_all(bytes)?;
    tmp.sync_all()?;

    let link_error = match fs::hard_link(tmp_path, path) {
        Ok(()) => return fs::remove_file(tmp_path),
        Err(err) => err,
    };
    if link_error.kind() != io::ErrorKind::AlreadyExists {
        return Err(link_error);
    }

    // Someone already published this name: accept only a byte-identical file.
    let existing = fs::read(path)?;
    fs::remove_file(tmp_path)?;
    if existing == bytes {
        return Ok(());
    }
    Err(io::Error::new(
        io::ErrorKind::AlreadyExists,
        format!(
            "regression path already exists with different content: {}. Fix: investigate hash collision or corrupt regression file.",
            path.display()
        ),
    ))
}

/// Build the name of a temporary sibling of `path`.
///
/// The result lives in the same directory and is named
/// `<file name>.tmp.<pid>.<nanos>.<counter>`, where `<nanos>` is the time
/// since the Unix epoch (0 if the clock reads earlier than the epoch) and
/// `<counter>` is the next value of the process-wide `TEMP_COUNTER`. When
/// `path` has no UTF-8 file name, `regression` is used as the stem.
fn temp_path(path: &std::path::Path) -> PathBuf {
    let base = match path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        None => "regression",
    };
    let pid = std::process::id();
    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };
    let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
    path.with_file_name(format!("{base}.tmp.{pid}.{nanos}.{counter}"))
}

/// Deserialization-side mirror of `ParityFailure` for regression JSON files.
///
/// Each field is moved unchanged into the `ParityFailure` field of the same
/// name by `into_failure`. No serde defaults are declared, so all nine fields
/// must be present for a file to deserialize; the loaders in this file skip
/// files that fail to parse.
#[derive(serde::Deserialize)]
struct PersistedFailure {
    /// Operation identifier; `save` uses the corresponding `ParityFailure`
    /// field to choose the regression directory.
    op_id: String,
    // NOTE(review): presumably the name of the generator that produced the
    // input — confirm against `ParityFailure`.
    generator: String,
    /// Label of the failing input; the `load_failures_*` functions sort by it.
    input_label: String,
    /// The failing input bytes; the only field `load_from_dir` keeps.
    input: Vec<u8>,
    // NOTE(review): presumably the output produced by the backend under test
    // — confirm against `ParityFailure`.
    gpu_output: Vec<u8>,
    // NOTE(review): presumably the CPU reference output — confirm against
    // `ParityFailure`.
    cpu_output: Vec<u8>,
    // NOTE(review): presumably a human-readable description of the mismatch
    // — confirm against `ParityFailure`.
    message: String,
    /// Spec version; `save` uses the corresponding `ParityFailure` field to
    /// pick the `v<version>` subdirectory.
    spec_version: u32,
    // NOTE(review): presumably the workgroup size in effect when the failure
    // was recorded — confirm against `ParityFailure`.
    workgroup_size: u32,
}

impl PersistedFailure {
    /// Convert the deserialized record into a `ParityFailure`, moving every
    /// field across under the same name without copying or altering it.
    fn into_failure(self) -> ParityFailure {
        let Self {
            op_id,
            generator,
            input_label,
            input,
            gpu_output,
            cpu_output,
            message,
            spec_version,
            workgroup_size,
        } = self;
        ParityFailure {
            op_id,
            generator,
            input_label,
            input,
            gpu_output,
            cpu_output,
            message,
            spec_version,
            workgroup_size,
        }
    }
}

mod hex;
use hex::*;

#[cfg(test)]
mod tests;