use crate::verify::budget::{budget_for_op, Archetype, BudgetTracker};
/// Outcome of a parity run comparing a kernel against a reference implementation.
///
/// As populated by `parity_10m_with_budget`, the counts cover only cases where both
/// the reference and the kernel ran, so a run aborted by a budget check is partial.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Report {
/// Number of cases for which both the reference and the kernel were executed.
pub total: u64,
/// Number of cases whose kernel output was byte-identical to the reference output.
pub matches: u64,
/// Every case whose outputs differed, in the order encountered.
pub mismatches: Vec<Mismatch>,
/// Input of the first mismatching case, or `None` if no case mismatched.
pub first_mismatch_input: Option<Vec<u8>>,
/// Budget-violation message when the run stopped early; `None` when the run completed.
pub explain: Option<String>,
}
/// One case where the kernel's output differed from the reference output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Mismatch {
/// Zero-based index of the case within the run.
pub case_index: u64,
/// The generated input fed to both implementations.
pub input: Vec<u8>,
/// Output produced by the reference implementation for `input`.
pub reference_output: Vec<u8>,
/// Output produced by the kernel under test for `input`.
pub kernel_output: Vec<u8>,
}
/// Number of pseudo-random cases each parity run attempts (10 million), unless a
/// budget check stops the run first.
const DEFAULT_CASES: u64 = 10_000_000;
/// Compares `kernel` against `reference` on up to 10 million seeded random inputs.
///
/// Behaves exactly like `parity_10m_with_budget` with no budget override. The reference
/// budget is therefore looked up from `op_id` and the archetype inferred from its name.
#[inline]
pub fn parity_10m(
    op_id: &str,
    reference: fn(&[u8]) -> Vec<u8>,
    kernel: fn(&[u8]) -> Vec<u8>,
    seed: u64,
) -> Report {
    let no_override = None;
    parity_10m_with_budget(op_id, reference, kernel, seed, no_override)
}
#[inline]
pub fn parity_10m_with_budget(
op_id: &str,
reference: fn(&[u8]) -> Vec<u8>,
kernel: fn(&[u8]) -> Vec<u8>,
seed: u64,
budget_override: Option<crate::verify::budget::ReferenceBudget>,
) -> Report {
let archetype = infer_archetype(op_id);
let budget = budget_override.unwrap_or_else(|| budget_for_op(op_id, &archetype));
let mut tracker = BudgetTracker::new(budget, op_id);
let mut rng = SplitMix64::new(seed);
let mut total = 0_u64;
let mut matches = 0_u64;
let mut mismatches = Vec::new();
let mut first_mismatch_input = None;
for case_index in 0..DEFAULT_CASES {
let input = random_input(&mut rng);
if let Err(bomb) = tracker.check_input(&input) {
return Report {
total,
matches,
mismatches,
first_mismatch_input,
explain: Some(bomb.to_string()),
};
}
let ref_start = std::time::Instant::now();
let reference_output = reference(&input);
let ref_elapsed = ref_start.elapsed();
if let Err(bomb) = tracker.record_case(ref_elapsed) {
return Report {
total,
matches,
mismatches,
first_mismatch_input,
explain: Some(bomb.to_string()),
};
}
let kernel_output = kernel(&input);
total += 1;
if reference_output == kernel_output {
matches += 1;
} else {
if first_mismatch_input.is_none() {
first_mismatch_input = Some(input.clone());
}
mismatches.push(Mismatch {
case_index,
input,
reference_output,
kernel_output,
});
}
}
Report {
total,
matches,
mismatches,
first_mismatch_input,
explain: None,
}
}
/// CI gate: runs `op_id`'s compiled spec on the GPU for up to [`DEFAULT_CASES`] inputs.
///
/// Inputs come from a SplitMix64 stream seeded by `seed_for(op_id)`, so a given op
/// replays the same input sequence on every run. For each case:
/// - the input is checked against the reference budget;
/// - the wall-clock time of `execute_op` is recorded against that budget.
///
/// # Errors
///
/// Returns `Err` in any of these cases:
/// - `op_id` is absent from the compiled registry;
/// - a budget check fails;
/// - `execute_op` reports an error for some case.
///
/// # Panics
///
/// Panics if `require_gpu()` does not yield a backend.
// NOTE(review): despite the name, this only checks that `execute_op` succeeds; the value
// it produces is discarded rather than compared against a reference. The budget also
// always uses the "binary-bitwise" archetype instead of `infer_archetype(op_id)`.
// Confirm both are intentional.
#[inline]
pub fn parity_10m_blocking_ci(op_id: &str) -> Result<(), String> {
    use crate::pipeline::backend::require_gpu;
    use crate::pipeline::execution::execute_op;

    let registry = crate::spec::op_registry::compiled_specs();
    let spec = registry
        .into_iter()
        .find(|candidate| candidate.id == op_id)
        .ok_or_else(|| {
            format!(
                "Fix: op_id {} not found in compiled registry. \
                 Register the op before enabling parity_10m_blocking_ci.",
                op_id
            )
        })?;

    let backend = require_gpu().expect("vyre-conform test needs a GPU adapter");
    let budget = budget_for_op(op_id, &Archetype("binary-bitwise"));
    let mut tracker = BudgetTracker::new(budget, op_id);
    let mut rng = SplitMix64::new(seed_for(op_id));

    (0..DEFAULT_CASES).try_for_each(|case_index| -> Result<(), String> {
        let input = random_input(&mut rng);
        tracker.check_input(&input).map_err(|b| b.to_string())?;

        let clock = std::time::Instant::now();
        let outcome = execute_op(&backend, &spec, &input, 64);
        tracker
            .record_case(clock.elapsed())
            .map_err(|b| b.to_string())?;

        // The budget is charged before the op's own result is inspected.
        match outcome {
            Ok(_) => Ok(()),
            Err(msg) => Err(format!(
                "Fix: parity failure at case {} for {}: {}",
                case_index, op_id, msg
            )),
        }
    })
}
/// Picks a budget archetype for `op_id` from substrings of its name.
///
/// The rules are tried in this order:
/// 1. `hash` together with `u32`;
/// 2. `hash` together with `u64`;
/// 3. `decode`;
/// 4. `compress`.
///
/// Anything unmatched maps to `"unknown"`.
fn infer_archetype(op_id: &str) -> Archetype {
    let mentions = |needle: &str| op_id.contains(needle);
    let label = if mentions("hash") && mentions("u32") {
        "hash-bytes-to-u32"
    } else if mentions("hash") && mentions("u64") {
        "hash-bytes-to-u64"
    } else if mentions("decode") {
        "decode-bytes-to-bytes"
    } else if mentions("compress") {
        "compression-bytes-to-bytes"
    } else {
        "unknown"
    };
    Archetype(label)
}
/// Minimal SplitMix64 pseudo-random generator.
///
/// Each step advances the state by a fixed odd increment and returns a bit-mixed copy of
/// the new state, so the output stream is fully determined by the seed.
#[derive(Debug, Clone, Copy)]
struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Additive step applied to the state before every output.
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator whose first step starts from `seed`.
    fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Advances the state and returns the next 64-bit output.
    fn next(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::GAMMA);
        let z = self.state;
        let z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        let z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}
/// Draws one random input of 0..=64 bytes from `rng`.
///
/// One `next()` call picks the length. The bytes are then the little-endian encodings of
/// `ceil(len / 8)` further `next()` calls, with the last word cut short as needed.
fn random_input(rng: &mut SplitMix64) -> Vec<u8> {
    let len = (rng.next() % 65) as usize;
    let word_count = (len + 7) / 8;
    (0..word_count)
        .flat_map(|_| rng.next().to_le_bytes())
        .take(len)
        .collect()
}
/// Derives a deterministic 64-bit seed from `op_id`.
///
/// Each byte is XOR-ed into the accumulator, which is then multiplied (wrapping) by an
/// odd constant. The empty id yields the starting offset unchanged.
fn seed_for(op_id: &str) -> u64 {
    const OFFSET: u64 = 0xA076_1D64_78BD_642F;
    const MULTIPLIER: u64 = 0xE703_7ED1_A0B4_28DB;
    op_id
        .bytes()
        .fold(OFFSET, |acc, byte| (acc ^ u64::from(byte)).wrapping_mul(MULTIPLIER))
}