use crate::verify::budget::{budget_for_op, Archetype, BudgetTracker, ReferenceBombDetected};
use vyre_reference::{dual::ReferenceFn, resolve_dual};
/// Number of random cases that [`run_default`] executes for an op.
pub const DEFAULT_CASES: u64 = 1_000_000;
/// Summary returned when every generated case produced identical outputs from
/// both references.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DualReferenceReport {
/// Number of cases executed; equal to the `cases` count the run was asked for.
pub cases: u64,
}
/// Details of the first generated case on which the two references disagreed.
///
/// [`run`] also returns this shape (with `case_index` 0 and an empty `input`)
/// when no reference pair is registered for the op. In that case
/// `reference_a` / `reference_b` hold "Fix: missing ..." placeholder messages
/// rather than real outputs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DualReferenceDivergence {
/// Identifier of the op under test.
pub op_id: String,
/// Zero-based index of the generated case that diverged.
pub case_index: u64,
/// The generated input bytes that were fed to both references.
pub input: Vec<u8>,
/// Bytes returned by the first reference for `input`.
pub reference_a: Vec<u8>,
/// Bytes returned by the second reference for `input`.
pub reference_b: Vec<u8>,
}
/// Reason a dual-reference run stopped before completing all cases.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DualReferenceError {
    /// The two references produced different outputs for the same input.
    ///
    /// [`run`] also uses this variant, with placeholder payloads, when no
    /// reference pair is registered for the op.
    Divergence(DualReferenceDivergence),
    /// A budget tracker rejected an input or a recorded case.
    Bomb(ReferenceBombDetected),
}

impl std::fmt::Display for DualReferenceError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Divergence(divergence) => write!(
                f,
                "dual-reference divergence for op `{}` at case {} ({}-byte input)",
                divergence.op_id,
                divergence.case_index,
                divergence.input.len(),
            ),
            // Only `Debug` is guaranteed for `ReferenceBombDetected` (the derive
            // above requires it), so render the payload with `{:?}`.
            Self::Bomb(bomb) => write!(f, "reference bomb detected: {bomb:?}"),
        }
    }
}

/// Lets callers propagate this error through `Box<dyn Error>` / `?` chains.
impl std::error::Error for DualReferenceError {}
/// Runs the dual-reference check for `op_id` over [`DEFAULT_CASES`] random
/// inputs.
///
/// This is shorthand for `run(op_id, DEFAULT_CASES)`; see [`run`] for the
/// error conditions.
#[inline]
pub fn run_default(op_id: &str) -> Result<DualReferenceReport, DualReferenceError> {
    let cases = DEFAULT_CASES;
    run(op_id, cases)
}
/// Looks up the registered reference pair for `op_id` and checks them
/// against each other on `cases` generated inputs via [`run_with_references`].
///
/// # Errors
///
/// When `resolve_dual` has no pair for `op_id`, returns
/// [`DualReferenceError::Divergence`] at case 0 with an empty input. Its
/// `reference_a` / `reference_b` carry "Fix: missing ..." placeholder messages
/// instead of real outputs. Every other error comes unchanged from
/// [`run_with_references`].
#[inline]
pub fn run(op_id: &str, cases: u64) -> Result<DualReferenceReport, DualReferenceError> {
    match resolve_dual(op_id) {
        Some((reference_a, reference_b)) => {
            run_with_references(op_id, reference_a, reference_b, cases)
        }
        // NOTE(review): a missing registration is reported through the
        // `Divergence` variant with sentinel payloads; a dedicated variant
        // would be clearer but would break exhaustive matches downstream.
        None => Err(DualReferenceError::Divergence(DualReferenceDivergence {
            op_id: op_id.to_string(),
            case_index: 0,
            input: Vec::new(),
            reference_a: b"Fix: missing reference_a for op_id".to_vec(),
            reference_b: b"Fix: missing reference_b for op_id".to_vec(),
        })),
    }
}
/// Differentially tests two reference implementations of `op_id` against each
/// other on `cases` generated inputs.
///
/// Inputs come from a [`SplitMix64`] seeded with `seed_for(op_id)`, so the
/// input sequence is the same on every run for a given op. Each reference gets
/// its own budget tracker, built from the budget `budget_for_op` returns for
/// the op's inferred archetype. Per case, both trackers vet the input before
/// either reference runs. Each reference's wall-clock time is then recorded on
/// its own tracker, and the outputs are compared byte-for-byte.
///
/// # Errors
///
/// * [`DualReferenceError::Bomb`] as soon as either tracker rejects an input
///   or a recorded case (tracker A is always consulted before tracker B).
/// * [`DualReferenceError::Divergence`] for the first case whose outputs
///   differ.
#[inline]
pub fn run_with_references(
    op_id: &str,
    reference_a: ReferenceFn,
    reference_b: ReferenceFn,
    cases: u64,
) -> Result<DualReferenceReport, DualReferenceError> {
    // Kept in a named binding so anything `budget_for_op` borrows from it
    // stays alive for the whole run.
    let kind = infer_archetype(op_id);
    let budget = budget_for_op(op_id, &kind);
    let mut tracker_a = BudgetTracker::new(budget, op_id);
    let mut tracker_b = BudgetTracker::new(budget, op_id);
    let mut rng = SplitMix64::new(seed_for(op_id));
    let bomb = DualReferenceError::Bomb;

    for case_index in 0..cases {
        let input = random_input(&mut rng);

        // Both trackers vet the input before either reference executes.
        tracker_a.check_input(&input).map_err(bomb)?;
        tracker_b.check_input(&input).map_err(bomb)?;

        let (output_a, elapsed_a) = time_call(|| reference_a(&input));
        tracker_a.record_case(elapsed_a).map_err(bomb)?;

        let (output_b, elapsed_b) = time_call(|| reference_b(&input));
        tracker_b.record_case(elapsed_b).map_err(bomb)?;

        if output_a == output_b {
            continue;
        }
        return Err(DualReferenceError::Divergence(DualReferenceDivergence {
            op_id: op_id.to_string(),
            case_index,
            input,
            reference_a: output_a,
            reference_b: output_b,
        }));
    }

    Ok(DualReferenceReport { cases })
}

/// Invokes `call` once and returns its result along with the wall-clock time
/// the call took.
fn time_call<T>(call: impl FnOnce() -> T) -> (T, std::time::Duration) {
    let started = std::time::Instant::now();
    let value = call();
    (value, started.elapsed())
}
/// Guesses the budget archetype of `op_id` from substrings of its name.
///
/// Rules are tried in order and the first one whose substrings all occur in
/// `op_id` wins. So a `hash` op mentioning `u32` takes precedence over `u64`,
/// both take precedence over `decode`, and `decode` over `compress`. Names
/// matching no rule map to `"unknown"`.
fn infer_archetype(op_id: &str) -> Archetype {
    const RULES: [(&[&str], &str); 4] = [
        (&["hash", "u32"], "hash-bytes-to-u32"),
        (&["hash", "u64"], "hash-bytes-to-u64"),
        (&["decode"], "decode-bytes-to-bytes"),
        (&["compress"], "compression-bytes-to-bytes"),
    ];

    let label = RULES
        .iter()
        .find(|(needles, _)| needles.iter().all(|needle| op_id.contains(*needle)))
        .map_or("unknown", |&(_, label)| label);
    Archetype(label)
}
/// Small deterministic SplitMix64 pseudo-random generator used to produce
/// reproducible test inputs.
#[derive(Debug, Clone, Copy)]
struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Fixed amount the state advances by on every draw (the standard
    /// SplitMix64 increment).
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator starting from `seed`; the first draw mixes
    /// `seed + INCREMENT` (wrapping).
    fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Advances the state by [`Self::INCREMENT`] and returns a mixed 64-bit
    /// value derived from the new state.
    fn next(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::INCREMENT);
        let mut z = self.state;
        z ^= z >> 30;
        z = z.wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z ^= z >> 27;
        z = z.wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}
/// Draws one pseudo-random input of between 0 and 64 bytes (inclusive).
///
/// The first draw picks the length. The payload is the little-endian bytes of
/// just enough further draws to cover that length, cut down to size. The
/// generator therefore advances by `1 + ceil(len / 8)` draws.
fn random_input(rng: &mut SplitMix64) -> Vec<u8> {
    let target_len = (rng.next() % 65) as usize;
    let words_needed = target_len.div_ceil(8);
    let mut bytes: Vec<u8> = (0..words_needed)
        .flat_map(|_| rng.next().to_le_bytes())
        .collect();
    bytes.truncate(target_len);
    bytes
}
/// Derives a deterministic 64-bit seed from the bytes of `op_id`.
///
/// Starting from a fixed offset, each byte is XORed into the accumulator and
/// the result is multiplied (wrapping) by a fixed constant. The same op id
/// always yields the same seed.
fn seed_for(op_id: &str) -> u64 {
    const OFFSET: u64 = 0xA076_1D64_78BD_642F;
    const MULTIPLIER: u64 = 0xE703_7ED1_A0B4_28DB;
    op_id
        .bytes()
        .fold(OFFSET, |acc, byte| (acc ^ u64::from(byte)).wrapping_mul(MULTIPLIER))
}