use crate::spec::types::ParityFailure;
use std::fs;
use std::io::{self, Write};
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
/// Per-process sequence number that `temp_path` appends to temp file names, so two
/// writes from the same process that see the same nanosecond timestamp still get
/// distinct temp names.
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
#[cfg(loom)]
use loom::sync::Mutex as LoomMutex;
/// In-memory stand-in for the regression directory under `cfg(loom)`: `save` pushes
/// into it and `load_all_versions` reads from it instead of using the filesystem.
// NOTE(review): loom primitives are normally constructed inside `loom::model` (statics
// via `loom::lazy_static!`); confirm `LoomMutex::new` is usable in a `static`
// initializer here and that stored failures do not leak between model iterations.
#[cfg(loom)]
static LOOM_STORE: LoomMutex<Vec<ParityFailure>> = LoomMutex::new(Vec::new());
/// Returns every stored regression input for `op_id`, covering the legacy layout and all
/// versioned subdirectories.
///
/// This is a convenience alias for [`load_all_versions`]; see it for labelling and ordering.
#[inline]
pub fn load(op_id: &str) -> Vec<(String, Vec<u8>)> {
    load_all_versions(op_id)
}
/// Reads the replayable inputs stored directly inside `dir` (no recursion).
///
/// Recognised files, keyed by extension:
/// - `json`: a serialized failure; only its `input` bytes are kept.
/// - `hex`: hex text, trimmed of surrounding whitespace, decoded with `decode_hex`.
/// - `bin`: raw input bytes.
///
/// Anything else, including unreadable or malformed files, is skipped silently. A missing
/// or unreadable `dir` yields an empty list. Each entry is labelled
/// `regression:{tag}:{file_stem}`, and the stem falls back to `regression` when it is not
/// valid UTF-8. The result is sorted by label.
fn load_from_dir(dir: &std::path::Path, tag: &str) -> Vec<(String, Vec<u8>)> {
    let Ok(listing) = fs::read_dir(dir) else {
        return Vec::new();
    };
    let mut cases: Vec<(String, Vec<u8>)> = listing
        .flatten()
        .filter_map(|entry| {
            let path = entry.path();
            // Files without a (UTF-8) extension are not regression cases.
            let payload = match path.extension().and_then(|ext| ext.to_str())? {
                "json" => fs::read_to_string(&path)
                    .ok()
                    .and_then(|text| serde_json::from_str::<PersistedFailure>(&text).ok())
                    .map(|failure| failure.input),
                "hex" => fs::read_to_string(&path)
                    .ok()
                    .and_then(|text| decode_hex(text.trim()).ok()),
                "bin" => fs::read(&path).ok(),
                _ => None,
            }?;
            let stem = path
                .file_stem()
                .and_then(|stem| stem.to_str())
                .unwrap_or("regression");
            Some((format!("regression:{tag}:{stem}"), payload))
        })
        .collect();
    cases.sort_by(|a, b| a.0.cmp(&b.0));
    cases
}
/// Persists `failure` as pretty-printed JSON and returns the path it was written to.
///
/// Without `loom`, the file goes into the `v{spec_version}` subdirectory of the op's
/// regression directory, which is created if needed. It is named
/// `{sha256_hex(input)}.json` and published through `atomic_write_new`.
///
/// Under `cfg(loom)`, the failure is instead pushed into the in-memory `LOOM_STORE`, and the
/// placeholder path `loom-mem` is returned.
///
/// # Errors
/// Returns directory-creation, serialization and write errors from the helpers above.
#[inline]
pub fn save(failure: &ParityFailure) -> io::Result<PathBuf> {
    #[cfg(loom)]
    {
        LOOM_STORE.lock().unwrap().push(failure.clone());
        return Ok(PathBuf::from("loom-mem"));
    }
    #[cfg(not(loom))]
    {
        let target_dir = versioned_regression_dir(&failure.op_id, failure.spec_version);
        fs::create_dir_all(&target_dir)?;
        let payload = serialize_failure(failure)?;
        let target = target_dir.join(format!("{}.json", sha256_hex(&failure.input)));
        atomic_write_new(&target, &payload)?;
        Ok(target)
    }
}
/// Persists only the raw input bytes of `failure` and returns the written path.
///
/// The file goes into the unversioned (legacy) regression directory for the op, which is
/// created if needed. It is named `{sha256_hex(input)}.bin` and published through
/// `atomic_write_new`.
///
/// # Errors
/// Returns directory-creation and write errors.
#[inline]
pub fn save_binary(failure: &ParityFailure) -> io::Result<PathBuf> {
    let op_dir = regression_dir(&failure.op_id);
    fs::create_dir_all(&op_dir)?;
    let target = op_dir.join(format!("{}.bin", sha256_hex(&failure.input)));
    atomic_write_new(&target, &failure.input).map(|()| target)
}
#[inline]
pub fn load_failures_versioned(op_id: &str, version: u32) -> Vec<ParityFailure> {
let dir = versioned_regression_dir(op_id, version);
let Ok(entries) = fs::read_dir(dir) else {
return Vec::new();
};
let mut failures = Vec::new();
for entry in entries.flatten() {
let path = entry.path();
if path.extension().and_then(|ext| ext.to_str()) != Some("json") {
continue;
}
if let Ok(text) = fs::read_to_string(path) {
if let Ok(persisted) = serde_json::from_str::<PersistedFailure>(&text) {
failures.push(persisted.into_failure());
}
}
}
failures.sort_by(|a, b| a.input_label.cmp(&b.input_label));
failures
}
/// Loads the replayable inputs saved for `op_id` under spec `version` only.
///
/// Entries are labelled `regression:v{version}:{file_stem}` and sorted by label. Other
/// versions and the legacy directory are not included.
#[inline]
pub fn load_versioned(op_id: &str, version: u32) -> Vec<(String, Vec<u8>)> {
    let tag = format!("v{version}");
    load_from_dir(&versioned_regression_dir(op_id, version), &tag)
}
#[inline]
pub fn load_all_versions(op_id: &str) -> Vec<(String, Vec<u8>)> {
#[cfg(loom)]
{
let store = LOOM_STORE.lock().unwrap();
let mut results: Vec<_> = store
.iter()
.filter(|f| f.op_id == op_id)
.map(|f| (f.input_label.clone(), f.input.clone()))
.collect();
results.sort_by(|a, b| a.0.cmp(&b.0));
results
}
#[cfg(not(loom))]
{
let mut results = Vec::new();
let legacy_dir = regression_dir(op_id);
results.extend(load_from_dir(&legacy_dir, "legacy"));
if let Ok(entries) = fs::read_dir(&legacy_dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.is_dir() {
if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
if name.starts_with('v') {
results.extend(load_from_dir(&path, name));
}
}
}
}
}
results.sort_by(|a, b| a.0.cmp(&b.0));
results
}
}
#[inline]
pub fn load_all_versions_from_dir(dir: &std::path::Path) -> Vec<(String, Vec<u8>)> {
let mut results = Vec::new();
results.extend(load_from_dir(dir, "legacy"));
if let Ok(entries) = fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.is_dir() {
if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
if name.starts_with('v') {
results.extend(load_from_dir(&path, name));
}
}
}
}
}
results.sort_by(|a, b| a.0.cmp(&b.0));
results
}
#[inline]
pub fn load_failures_from_dir(dir: &std::path::Path) -> Vec<ParityFailure> {
let Ok(entries) = fs::read_dir(dir) else {
return Vec::new();
};
let mut failures = Vec::new();
for entry in entries.flatten() {
let path = entry.path();
if path.extension().and_then(|ext| ext.to_str()) != Some("json") {
continue;
}
if let Ok(text) = fs::read_to_string(path) {
if let Ok(persisted) = serde_json::from_str::<PersistedFailure>(&text) {
failures.push(persisted.into_failure());
}
}
}
failures.sort_by(|a, b| a.input_label.cmp(&b.input_label));
failures
}
/// Directory holding the regressions for `op_id`:
/// `<CARGO_MANIFEST_DIR>/regressions/<sanitize(op_id)>`.
///
/// The manifest directory is captured at compile time via `env!`.
pub(super) fn regression_dir(op_id: &str) -> PathBuf {
    let mut dir = PathBuf::from(std::env!("CARGO_MANIFEST_DIR"));
    dir.push("regressions");
    dir.push(sanitize(op_id));
    dir
}
/// Subdirectory of `regression_dir(op_id)` that holds failures for spec `version`, named
/// `v{version}`.
fn versioned_regression_dir(op_id: &str, version: u32) -> PathBuf {
    let mut dir = regression_dir(op_id);
    dir.push(format!("v{version}"));
    dir
}
pub(super) fn sanitize(value: &str) -> String {
let mut out = String::with_capacity(value.len());
for byte in value.bytes() {
if byte.is_ascii_alphanumeric() || byte == b'-' || byte == b'_' {
out.push(byte as char);
} else {
out.push('%');
out.push(nibble(byte >> 4).to_ascii_uppercase());
out.push(nibble(byte & 0x0F).to_ascii_uppercase());
}
}
out
}
#[cfg(test)]
pub(super) fn encode_hex(bytes: &[u8]) -> String {
let mut out = String::with_capacity(bytes.len() * 2);
for byte in bytes {
out.push(nibble(byte >> 4));
out.push(nibble(byte & 0x0F));
}
out
}
/// Decodes a hex string (either case, no separators or prefix) into bytes.
///
/// # Errors
/// Returns a message when the length is odd or when any character is not a hex digit. The
/// first offending pair, high digit first, determines the error.
pub(super) fn decode_hex(text: &str) -> Result<Vec<u8>, String> {
    let digits = text.as_bytes();
    if digits.len() % 2 == 1 {
        return Err("hex input has odd length. Fix: use two hex chars per byte.".to_string());
    }
    digits
        .chunks_exact(2)
        .map(|pair| -> Result<u8, String> {
            let hi = from_hex(pair[0])?;
            let lo = from_hex(pair[1])?;
            Ok((hi << 4) | lo)
        })
        .collect()
}
/// Maps a value in `0..=15` to its lowercase hex digit character.
///
/// # Panics
/// Panics when `value > 15`, because the digit-table index is out of bounds.
pub(super) fn nibble(value: u8) -> char {
    const LOWER_HEX: &[u8; 16] = b"0123456789abcdef";
    char::from(LOWER_HEX[usize::from(value)])
}
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`.
///
/// # Errors
/// Returns a message for any other byte, including non-ASCII bytes.
pub(super) fn from_hex(value: u8) -> Result<u8, String> {
    // `to_digit(16)` accepts exactly the ASCII hex digits and yields 0..=15.
    char::from(value)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
        .ok_or_else(|| "invalid hex byte. Fix: use characters 0-9, a-f, or A-F.".to_string())
}
/// Serializes `failure` to pretty-printed JSON bytes.
///
/// # Errors
/// A `serde_json` failure is reported as an `io::ErrorKind::InvalidData` error, so callers
/// can use `?` inside `io::Result` code.
fn serialize_failure(failure: &ParityFailure) -> io::Result<Vec<u8>> {
    match serde_json::to_vec_pretty(failure) {
        Ok(encoded) => Ok(encoded),
        Err(err) => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("could not serialize regression failure: {err}. Fix: persist JSON-compatible ParityFailure fields."),
        )),
    }
}
/// Writes `bytes` to `path` without ever exposing a partially written file.
///
/// The data is first written to a fresh sibling temp file from `temp_path`, opened with
/// `create_new` so an existing file is never reused. `write_and_commit` then publishes it.
///
/// # Errors
/// If anything fails, the temp file is removed best-effort and the original error is
/// returned.
fn atomic_write_new(path: &std::path::Path, bytes: &[u8]) -> io::Result<()> {
    let staging = temp_path(path);
    let mut handle = fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(&staging)?;
    write_and_commit(&mut handle, &staging, path, bytes).map_err(|err| {
        // The temp file may already be gone; a cleanup failure must not mask `err`.
        let _ = fs::remove_file(&staging);
        err
    })
}
/// Fills the already-open temp file with `bytes`, syncs it, and publishes it at `path`.
///
/// Publishing uses `fs::hard_link`, which fails with `AlreadyExists` rather than replacing an
/// existing file, so a committed regression is never overwritten. If `path` already exists,
/// the call succeeds only when its content equals `bytes`.
///
/// Removing `tmp_path` after a successful link, or after detecting an identical existing
/// file, is best-effort. At that point `path` already holds the right content, so a cleanup
/// failure is not an error for the caller. Temp names from `temp_path` end in a numeric
/// counter, so a leftover one is not picked up by the `json`/`hex`/`bin` loaders.
///
/// # Errors
/// - I/O errors from writing or syncing the temp file, creating the link, or reading the
///   existing file.
/// - `io::ErrorKind::AlreadyExists` when `path` exists with different content.
fn write_and_commit(
    tmp: &mut fs::File,
    tmp_path: &std::path::Path,
    path: &std::path::Path,
    bytes: &[u8],
) -> io::Result<()> {
    tmp.write_all(bytes)?;
    // Flush the data to storage before it becomes visible under `path`.
    tmp.sync_all()?;
    match fs::hard_link(tmp_path, path) {
        Ok(()) => {
            // Committed: the temp name is now just a second link to the same data.
            let _ = fs::remove_file(tmp_path);
            Ok(())
        }
        Err(err) if err.kind() == io::ErrorKind::AlreadyExists => {
            let existing = fs::read(path)?;
            let _ = fs::remove_file(tmp_path);
            if existing == bytes {
                Ok(())
            } else {
                // NOTE: `save` names JSON files by the input hash only, so re-saving the same
                // input with different metadata (generator, label, outputs, ...) also ends up
                // here, not just a true hash collision.
                Err(io::Error::new(
                    io::ErrorKind::AlreadyExists,
                    format!(
                        "regression path already exists with different content: {}. Fix: investigate hash collision or corrupt regression file.",
                        path.display()
                    ),
                ))
            }
        }
        Err(err) => Err(err),
    }
}
/// Builds a sibling temp path for `path`:
/// `{file_name}.tmp.{pid}.{unix_nanos}.{counter}`.
///
/// The process id, wall-clock nanoseconds (0 if the clock reads before the Unix epoch) and
/// the per-process `TEMP_COUNTER` together keep concurrent writers from colliding. When the
/// file name is missing or not UTF-8, `regression` is used as the base name.
fn temp_path(path: &std::path::Path) -> PathBuf {
    let pid = std::process::id();
    let stamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or_default();
    let seq = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
    let base = match path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        None => "regression",
    };
    path.with_file_name(format!("{base}.tmp.{pid}.{stamp}.{seq}"))
}
/// Deserialization target for the `*.json` failure files written by `save`.
///
/// For `save` to round-trip through the loaders, these fields must match the keys
/// `serde_json` emits for [`ParityFailure`]. This assumes no conflicting `#[serde(rename)]`
/// on that type; confirm if its definition changes.
#[derive(serde::Deserialize)]
struct PersistedFailure {
    op_id: String,
    generator: String,
    input_label: String,
    // Raw input bytes replayed by the regression loaders.
    input: Vec<u8>,
    gpu_output: Vec<u8>,
    cpu_output: Vec<u8>,
    message: String,
    spec_version: u32,
    workgroup_size: u32,
}
impl PersistedFailure {
    /// Moves every field into an in-memory [`ParityFailure`] unchanged.
    fn into_failure(self) -> ParityFailure {
        let Self {
            op_id,
            generator,
            input_label,
            input,
            gpu_output,
            cpu_output,
            message,
            spec_version,
            workgroup_size,
        } = self;
        ParityFailure {
            op_id,
            generator,
            input_label,
            input,
            gpu_output,
            cpu_output,
            message,
            spec_version,
            workgroup_size,
        }
    }
}
mod hex;
use hex::*;
#[cfg(test)]
mod tests;