use crate::{generate::generators, OpSpec};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fs;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
/// Per-process counter mixed into staging-file names by `temp_path`, so two
/// staging paths created by this process within the same clock reading still differ.
static GOLDEN_TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
/// One frozen golden case: the bytes fed to an op's CPU reference function and
/// the bytes it returned when the golden was frozen.
///
/// `freeze_goldens` writes each value as pretty-printed JSON and `load_goldens`
/// parses it back, so the field names here define the on-disk JSON keys.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Golden {
/// Id of the op (`OpSpec::id`) this golden was recorded for.
pub op_id: String,
/// `OpSpec::version` at the time the golden was frozen.
pub spec_version: u32,
/// Raw input bytes passed to the op's `cpu_fn`.
pub input: Vec<u8>,
/// Raw bytes `cpu_fn` returned for `input` at freeze time.
pub output: Vec<u8>,
}
/// Outcome of `replay_goldens`; `passed + failures.len()` always equals `tested`.
#[derive(Debug, Clone, Default)]
pub struct GoldenReport {
/// Number of goldens handed to `replay_goldens`.
pub tested: usize,
/// Number of goldens whose replayed output matched byte-for-byte.
pub passed: usize,
/// One entry per golden that did not pass, including goldens whose op id
/// matched no spec.
pub failures: Vec<GoldenMismatch>,
}
/// Details of a single golden that failed to replay.
#[derive(Debug, Clone)]
pub struct GoldenMismatch {
/// Op id recorded in the golden.
pub op_id: String,
/// Spec version recorded in the golden. Replay matches specs by id only, so
/// this is not necessarily the version of the spec that was run.
pub spec_version: u32,
/// `sha256_hex` of the golden's input; `freeze_goldens` uses the same value
/// as the golden's file stem.
pub input_hash: String,
/// Output stored in the golden.
pub expected_output: Vec<u8>,
/// Output produced on replay. Left empty when no spec with a matching id was
/// found, which is indistinguishable from an op that returned no bytes.
pub actual_output: Vec<u8>,
}
/// Loads every golden stored as a `.json` file anywhere under `root`.
///
/// A `root` that does not exist yields an empty list rather than an error.
/// Directory entries are skipped even when their name ends in `.json`. Every
/// other entry with a `json` extension is parsed as a `Golden`.
///
/// The result is sorted by op id, then spec version, then input bytes, then
/// output bytes. The order is therefore fully deterministic and does not depend
/// on how the filesystem lists directory entries.
///
/// # Errors
///
/// - a directory-walk failure, wrapped with `io::Error::other`;
/// - a read failure, keeping its original `ErrorKind` and naming the file;
/// - a file that is not valid golden JSON, as `ErrorKind::InvalidData`.
#[inline]
pub fn load_goldens(root: &Path) -> io::Result<Vec<Golden>> {
    let mut out = Vec::new();
    if !root.exists() {
        return Ok(out);
    }
    for entry in walkdir::WalkDir::new(root) {
        let entry = entry.map_err(io::Error::other)?;
        // A directory such as `foo.json/` would otherwise reach `read_to_string`
        // and abort the whole load.
        if entry.file_type().is_dir() {
            continue;
        }
        let path = entry.path();
        if path.extension().and_then(|e| e.to_str()) != Some("json") {
            continue;
        }
        let text = fs::read_to_string(path).map_err(|e| {
            io::Error::new(
                e.kind(),
                format!("failed to read golden {}: {}", path.display(), e),
            )
        })?;
        let golden: Golden = serde_json::from_str(&text).map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("invalid golden JSON at {}: {}", path.display(), e),
            )
        })?;
        out.push(golden);
    }
    // Many goldens share (op_id, spec_version), because freeze_goldens writes one
    // file per input. Tie-break on content so the result does not follow walk order.
    out.sort_by(|a, b| {
        a.op_id
            .cmp(&b.op_id)
            .then(a.spec_version.cmp(&b.spec_version))
            .then_with(|| a.input.cmp(&b.input))
            .then_with(|| a.output.cmp(&b.output))
    });
    Ok(out)
}
/// Re-runs each golden through its op's CPU reference function and compares the
/// result with the frozen output byte-for-byte.
///
/// Goldens are matched to specs by `op_id` only. `spec_version` is copied into
/// any mismatch but is not compared. When several specs share an id, the first
/// one in `specs` is used. A golden whose op id matches no spec is reported as a
/// failure with an empty `actual_output`.
///
/// The returned report always satisfies `passed + failures.len() == tested`.
#[inline]
pub fn replay_goldens(goldens: &[Golden], specs: &[OpSpec]) -> GoldenReport {
    use std::collections::HashMap;

    // Index specs once instead of scanning `specs` for every golden.
    // `or_insert` keeps the first spec per id, matching a front-to-back search.
    let mut by_id: HashMap<&str, &OpSpec> = HashMap::with_capacity(specs.len());
    for spec in specs {
        by_id.entry(spec.id).or_insert(spec);
    }

    let mut report = GoldenReport {
        tested: goldens.len(),
        ..Default::default()
    };
    for golden in goldens {
        // `None` means no spec exists for this op id.
        let actual = by_id
            .get(golden.op_id.as_str())
            .map(|spec| (spec.cpu_fn)(&golden.input));
        match actual {
            Some(actual) if actual == golden.output => report.passed += 1,
            other => report.failures.push(GoldenMismatch {
                op_id: golden.op_id.clone(),
                spec_version: golden.spec_version,
                input_hash: sha256_hex(&golden.input),
                expected_output: golden.output.clone(),
                actual_output: other.unwrap_or_default(),
            }),
        }
    }
    report
}
/// Freezes goldens for every spec under `out_dir` and returns how many new
/// golden files were written.
///
/// Layout: `out_dir/<sanitize(op id)>/v<version>/<sha256_hex(input)>.json`.
/// Inputs come from `generate_50_inputs`, seeded with `fnv1a_u64` of the op id.
/// Reruns should therefore revisit the same inputs, assuming the generators are
/// deterministic. A golden whose file already exists is skipped and not counted,
/// so repeated runs only fill in missing goldens.
///
/// # Errors
///
/// Propagates failures from creating directories or writing golden files.
/// JSON serialization failures are reported as `ErrorKind::InvalidData`.
#[inline]
pub fn freeze_goldens(specs: &[OpSpec], out_dir: &Path) -> io::Result<usize> {
    let mut written = 0;
    for spec in specs {
        let inputs = generate_50_inputs(spec, fnv1a_u64(spec.id.as_bytes()));
        let version_dir = out_dir
            .join(sanitize(spec.id))
            .join(format!("v{}", spec.version));
        fs::create_dir_all(&version_dir)?;

        for input in inputs {
            let target = version_dir.join(format!("{}.json", sha256_hex(&input)));
            // A frozen golden is never regenerated or overwritten.
            if target.exists() {
                continue;
            }
            let golden = Golden {
                op_id: spec.id.to_string(),
                spec_version: spec.version,
                output: (spec.cpu_fn)(&input),
                input,
            };
            let json = match serde_json::to_string_pretty(&golden) {
                Ok(text) => text,
                Err(e) => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("JSON serialize: {e}"),
                    ))
                }
            };
            atomic_write_new(&target, json.as_bytes())?;
            written += 1;
        }
    }
    Ok(written)
}
/// Creates `path` containing `bytes` without ever replacing an existing file.
///
/// The data is written and fsynced to a uniquely named staging file next to
/// `path` (named by `temp_path`), then published with `fs::hard_link`, which
/// fails rather than overwrite an existing `path`. An `AlreadyExists` error from
/// any publishing step is reported as success; in practice it comes from
/// `hard_link` when `path` already exists.
///
/// # Errors
///
/// Returns an error if the staging file cannot be created, or if writing,
/// syncing, linking or removing it fails for any other reason. After a failure
/// past creation, the staging file is removed on a best-effort basis.
fn atomic_write_new(path: &Path, bytes: &[u8]) -> io::Result<()> {
    /// Writes and syncs the staging file, links it to `target`, then removes the staging name.
    fn stage_and_publish(
        file: &mut fs::File,
        bytes: &[u8],
        staging: &Path,
        target: &Path,
    ) -> io::Result<()> {
        file.write_all(bytes)?;
        file.sync_all()?;
        fs::hard_link(staging, target)?;
        fs::remove_file(staging)
    }

    let staging = temp_path(path);
    // `create_new` guarantees the staging file is ours. If the name is already
    // taken, that error is returned without touching the existing file.
    let mut file = fs::OpenOptions::new()
        .create_new(true)
        .write(true)
        .open(&staging)?;

    match stage_and_publish(&mut file, bytes, &staging, path) {
        Ok(()) => Ok(()),
        Err(err) => {
            // Best-effort cleanup; the caller only sees the original error.
            let _ = fs::remove_file(&staging);
            if err.kind() == io::ErrorKind::AlreadyExists {
                Ok(())
            } else {
                Err(err)
            }
        }
    }
}
/// Builds a sibling path for staging a write to `path`.
///
/// The name is the target's file name followed by
/// `.tmp.<pid>.<unix-nanos>.<counter>`. If `path` has no UTF-8 file name,
/// `golden` is used instead. If the system clock reads earlier than the Unix
/// epoch, the nanosecond field is `0`. The counter keeps names distinct even
/// when the clock does not advance between calls.
fn temp_path(path: &Path) -> PathBuf {
    let stem = match path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        None => "golden",
    };
    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };
    let sequence = GOLDEN_TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
    let suffix = format!(".tmp.{}.{}.{}", std::process::id(), nanos, sequence);
    path.with_file_name(format!("{stem}{suffix}"))
}
/// Collects up to 50 distinct inputs for `spec`, derived from `seed`.
///
/// Inputs from the default generators that handle `spec.signature` are taken
/// first, in generator order, and duplicates are dropped. If they supply fewer
/// than 50 distinct inputs, the rest are drawn from `random_input_for_signature`
/// using an `XorShift64` seeded with `seed`.
///
/// Random filling stops once `MAX_CONSECUTIVE_DUPLICATES` draws in a row
/// produce nothing new. Without that bound, a signature with a tiny input space
/// would loop forever. The most obvious case is a signature with no inputs,
/// whose random input is always empty. In such cases fewer than 50 inputs are
/// returned.
fn generate_50_inputs(spec: &OpSpec, seed: u64) -> Vec<Vec<u8>> {
    const TARGET: usize = 50;
    const MAX_CONSECUTIVE_DUPLICATES: usize = 1_000;

    let gens = generators::default_generators();
    let mut seen = HashSet::new();
    let mut out = Vec::with_capacity(TARGET);
    for generator in &gens {
        if !generator.handles(&spec.signature) {
            continue;
        }
        for (_label, bytes) in generator.generate_for_op(spec.id, &spec.signature, seed) {
            if seen.insert(bytes.clone()) {
                out.push(bytes);
                if out.len() >= TARGET {
                    return out;
                }
            }
        }
    }

    let mut rng = XorShift64::new(seed);
    let mut duplicates_in_a_row = 0;
    while out.len() < TARGET && duplicates_in_a_row < MAX_CONSECUTIVE_DUPLICATES {
        let bytes = random_input_for_signature(&mut rng, &spec.signature);
        if seen.insert(bytes.clone()) {
            out.push(bytes);
            duplicates_in_a_row = 0;
        } else {
            duplicates_in_a_row += 1;
        }
    }
    out
}
/// Builds one random input for `signature` by concatenating an encoding for
/// each entry of `signature.inputs`, in order.
///
/// Encodings per input type:
/// - `U32`, `I32`: 4 little-endian bytes from one `next_u32` draw. The `I32`
///   value is the same bits reinterpreted, so its bytes are identical.
/// - `U64`: 8 little-endian bytes from one `next_u64` draw.
/// - `Vec2U32` / `Vec4U32`: 2 / 4 `next_u32` draws, each 4 little-endian bytes.
/// - `Bytes`, `Array { .. }`: a blob of 0..=4096 random bytes.
/// - any other type: a blob of 0..=63 random bytes.
///
/// A blob's length is one `next_u32` draw reduced modulo the bound. Each blob
/// byte is the low 8 bits of its own `next_u32` draw.
fn random_input_for_signature(
    rng: &mut XorShift64,
    signature: &crate::spec::types::OpSignature,
) -> Vec<u8> {
    use crate::spec::types::DataType;

    /// Appends `count` words, each one `next_u32` draw encoded little-endian.
    fn push_u32_words(rng: &mut XorShift64, out: &mut Vec<u8>, count: usize) {
        for _ in 0..count {
            out.extend_from_slice(&rng.next_u32().to_le_bytes());
        }
    }

    /// Appends a blob whose length is `next_u32() % modulus`, one draw per byte.
    fn push_random_blob(rng: &mut XorShift64, out: &mut Vec<u8>, modulus: usize) {
        let len = (rng.next_u32() as usize) % modulus;
        out.extend((0..len).map(|_| rng.next_u32() as u8));
    }

    let mut encoded = Vec::new();
    for input_ty in &signature.inputs {
        match input_ty {
            // A u32 -> i32 cast keeps the bit pattern, so both encode identically.
            DataType::U32 | DataType::I32 => push_u32_words(rng, &mut encoded, 1),
            DataType::U64 => encoded.extend_from_slice(&rng.next_u64().to_le_bytes()),
            DataType::Vec2U32 => push_u32_words(rng, &mut encoded, 2),
            DataType::Vec4U32 => push_u32_words(rng, &mut encoded, 4),
            DataType::Bytes | DataType::Array { .. } => {
                push_random_blob(rng, &mut encoded, 4097)
            }
            _ => push_random_blob(rng, &mut encoded, 64),
        }
    }
    encoded
}
mod util;
use util::*;
#[cfg(test)]
mod tests;