use vyre::ir::Program;
use crate::proof::comparator::ComparatorKind;
use crate::spec::value::Value;
use crate::verify::harnesses::backend::backend_registry;
use vyre::Error;
use vyre_reference;
/// Reasons a reference diff cannot produce a report at all.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReferenceDiffError {
    /// The reference interpreter could not supply ground truth for the program.
    ///
    /// `reference_diff_with_tolerance` returns this when the reference run fails
    /// with an error that `is_pending_reference_error` classifies as pending.
    NoGroundTruth,
}

impl core::fmt::Display for ReferenceDiffError {
    /// Writes the diagnostic text for the error, including the suggested fix.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Exhaustive match on purpose: adding a variant must force a new message.
        let message = match self {
            Self::NoGroundTruth => {
                "reference interpreter returned Pending. \
                 Fix: implement the missing reference variant before enabling this op in CI."
            }
        };
        f.write_str(message)
    }
}

impl core::error::Error for ReferenceDiffError {}
/// Result of running one program through the reference interpreter and every
/// registered backend.
#[derive(Debug, Clone)]
pub struct ReferenceDiffReport {
/// What the reference interpreter produced for the program.
pub reference_status: ReferenceStatus,
/// One entry per backend yielded by `backend_registry()`, in iteration order.
pub backend_results: Vec<BackendDiffResult>,
/// `reference_diff_with_tolerance` currently always stores `None` here.
/// NOTE(review): presumably reserved for a shrunk reproducer input — confirm.
pub minimized_input: Option<Vec<Value>>,
}
/// Outcome of the reference interpreter run recorded in a `ReferenceDiffReport`.
#[derive(Debug, Clone, PartialEq)]
pub enum ReferenceStatus {
/// The reference interpreter succeeded; the payload is its output values.
Ok(Vec<Value>),
/// NOTE(review): not constructed anywhere in the visible code —
/// `reference_diff_with_tolerance` returns `ReferenceDiffError::NoGroundTruth`
/// for pending errors instead of reporting this variant. Confirm other users.
Pending,
/// The reference interpreter failed with an error that was not classified as
/// pending.
Err(Error),
}
/// Result of running the program on a single backend, plus its divergence flag.
#[derive(Debug, Clone)]
pub struct BackendDiffResult {
/// Name reported by the backend's `name()` method.
pub backend_name: String,
/// Values returned by the backend's `run`, or its error message.
pub result: Result<Vec<Value>, String>,
/// When the reference run succeeded: `true` if this backend failed or its
/// values differ from the reference output.
/// When it did not: `true` if this backend disagrees with any other backend,
/// or if fewer than two backends ran (see
/// `mark_backend_equivalence_divergences`).
pub diverges_from_reference: bool,
}
/// Diffs every registered backend against the reference interpreter using
/// exact equality.
///
/// Equivalent to calling `reference_diff_with_tolerance` with a tolerance of
/// `None`, which makes value comparisons use `==`.
///
/// # Errors
///
/// Propagates whatever `reference_diff_with_tolerance` returns, i.e.
/// `ReferenceDiffError::NoGroundTruth` when the reference run is classified as
/// pending.
#[inline]
pub fn reference_diff(
program: &Program,
inputs: &[Value],
) -> Result<ReferenceDiffReport, ReferenceDiffError> {
// `None` selects the strict (non-approximate) comparison path.
reference_diff_with_tolerance(program, inputs, None)
}
/// Runs `program` on the reference interpreter and on every registered backend,
/// flagging each backend whose result disagrees.
///
/// * `program` – the IR program handed to the reference interpreter and to each
///   backend's `run`.
/// * `inputs` – input values; converted with `Value::to_reference_values` for
///   the reference run and passed unchanged to the backends.
/// * `tolerance_ulps` – `None` compares values with `==`; `Some(n)` compares the
///   serialized bytes through `ComparatorKind::Approximate { epsilon: n }`.
///
/// When the reference run fails with a non-pending error there is no expected
/// output, so the backends are compared against each other instead (see
/// `mark_backend_equivalence_divergences`).
///
/// `minimized_input` in the returned report is always `None`.
///
/// # Errors
///
/// Returns [`ReferenceDiffError::NoGroundTruth`] when the reference interpreter
/// fails with an error that `is_pending_reference_error` classifies as pending.
#[inline]
pub fn reference_diff_with_tolerance(
    program: &Program,
    inputs: &[Value],
    tolerance_ulps: Option<u32>,
) -> Result<ReferenceDiffReport, ReferenceDiffError> {
    let reference_inputs = Value::to_reference_values(inputs);
    let reference_result =
        vyre_reference::run(program, &reference_inputs).map(Value::from_reference_values);
    let reference_status = match reference_result {
        Ok(values) => ReferenceStatus::Ok(values),
        Err(e) if is_pending_reference_error(&e) => return Err(ReferenceDiffError::NoGroundTruth),
        Err(e) => ReferenceStatus::Err(e),
    };

    // Borrow the expected output from the status rather than cloning it; the
    // borrow ends before `reference_status` is moved into the report below.
    // The match is exhaustive so a new status variant forces a decision here.
    let expected: Option<&[Value]> = match &reference_status {
        ReferenceStatus::Ok(values) => Some(values.as_slice()),
        ReferenceStatus::Pending | ReferenceStatus::Err(_) => None,
    };

    let mut backend_results = Vec::new();
    for backend in backend_registry() {
        let result = backend.run(program, inputs);
        // Without ground truth nothing can diverge from it yet; the
        // backend-vs-backend pass below fills the flag in that case.
        let diverges_from_reference = expected.is_some_and(|expected| {
            result_diverges_from_expected(&result, expected, tolerance_ulps)
        });
        backend_results.push(BackendDiffResult {
            backend_name: backend.name().to_string(),
            result,
            diverges_from_reference,
        });
    }

    if expected.is_none() {
        mark_backend_equivalence_divergences(&mut backend_results, tolerance_ulps);
    }

    Ok(ReferenceDiffReport {
        reference_status,
        backend_results,
        minimized_input: None,
    })
}
fn is_pending_reference_error(error: &Error) -> bool {
let Error::Interp { message } = error else {
return false;
};
message.contains("unsupported IR") || message.contains("pending upstream float variants")
}
/// Decides whether one backend `result` disagrees with the reference output.
///
/// A backend error always counts as divergence; a successful run diverges when
/// `values_equal` rejects its values against `expected` under `tolerance_ulps`.
fn result_diverges_from_expected(
    result: &Result<Vec<Value>, String>,
    expected: &[Value],
    tolerance_ulps: Option<u32>,
) -> bool {
    match result {
        Ok(actual) => !values_equal(actual, expected, tolerance_ulps),
        Err(_) => true,
    }
}
/// Sets `diverges_from_reference` on every entry by comparing the backends with
/// each other; used when there is no reference output to compare against.
///
/// With fewer than two entries no cross-check is possible, so every entry is
/// flagged. Otherwise an entry is flagged when it is not equivalent (per
/// `backend_results_equivalent`) to at least one other entry.
fn mark_backend_equivalence_divergences(
    backend_results: &mut [BackendDiffResult],
    tolerance_ulps: Option<u32>,
) {
    if backend_results.len() < 2 {
        backend_results
            .iter_mut()
            .for_each(|entry| entry.diverges_from_reference = true);
        return;
    }

    // The comparisons read only `result`, never the flag being written, so
    // computing all flags first and applying them afterwards gives the same
    // outcome as updating in place.
    let flags: Vec<bool> = backend_results
        .iter()
        .enumerate()
        .map(|(index, current)| {
            backend_results
                .iter()
                .enumerate()
                .filter(|(other_index, _)| *other_index != index)
                .any(|(_, other)| {
                    !backend_results_equivalent(&current.result, &other.result, tolerance_ulps)
                })
        })
        .collect();

    for (entry, diverges) in backend_results.iter_mut().zip(flags) {
        entry.diverges_from_reference = diverges;
    }
}
/// Tells whether two backend outcomes agree with each other.
///
/// Two successes agree when `values_equal` accepts them under
/// `tolerance_ulps`; two failures agree only when their messages are identical;
/// a success never agrees with a failure.
fn backend_results_equivalent(
    left: &Result<Vec<Value>, String>,
    right: &Result<Vec<Value>, String>,
    tolerance_ulps: Option<u32>,
) -> bool {
    match left {
        Ok(left_values) => match right {
            Ok(right_values) => values_equal(left_values, right_values, tolerance_ulps),
            Err(_) => false,
        },
        Err(left_message) => match right {
            Ok(_) => false,
            Err(right_message) => left_message == right_message,
        },
    }
}
/// Compares two value lists, exactly or approximately.
///
/// * `tolerance_ulps == None` – strict track: the slices are compared with `==`.
/// * `tolerance_ulps == Some(n)` – approximate track: both sides are serialized
///   with `values_to_bytes` and accepted when
///   `ComparatorKind::Approximate { epsilon: n }.compare` returns `Ok`.
///
/// The approximate track compares the concatenated bytes only, so how those
/// bytes are split across individual `Value`s is not itself compared.
/// NOTE(review): the parameter name implies `epsilon` is measured in ULPs;
/// confirm against `ComparatorKind::Approximate`.
fn values_equal(left: &[Value], right: &[Value], tolerance_ulps: Option<u32>) -> bool {
    // Bind the tolerance while testing for it, so no `expect` is needed later.
    let Some(epsilon) = tolerance_ulps else {
        return left == right;
    };
    let left_bytes = values_to_bytes(left);
    let right_bytes = values_to_bytes(right);
    ComparatorKind::Approximate { epsilon }
        .compare(&left_bytes, &right_bytes)
        .is_ok()
}
/// Concatenates the `Value::to_bytes` serialization of every value, in order,
/// into one flat byte buffer.
fn values_to_bytes(values: &[Value]) -> Vec<u8> {
    let mut bytes = Vec::new();
    for value in values {
        bytes.extend(Value::to_bytes(value));
    }
    bytes
}
// Unit tests for the reference diff harness: registry population, divergence
// detection against the reference, backend-vs-backend divergence, and the
// strict vs. approximate comparison tracks.
#[cfg(test)]
mod tests {
use super::{
mark_backend_equivalence_divergences, reference_diff, values_equal, BackendDiffResult,
};
use crate::spec::value::Value;
use crate::verify::harnesses::backend::{register_backend, HarnessBackend};
use vyre::ir::Program;
// A program with an empty buffer list and a single `Return` node must yield a
// report that includes a backend named "reference".
// NOTE(review): relies on the registry already containing a backend with that
// name; the registration is not visible in this file.
#[test]
fn reference_diff_populates_backend_results_from_registry() {
let program = Program::new(vec![], [1, 1, 1], vec![vyre::ir::Node::Return]);
let report = reference_diff(&program, &[])
.expect("reference diff should succeed for supported program");
assert!(
report
.backend_results
.iter()
.any(|r| r.backend_name == "reference"),
"backend_results must contain the reference backend"
);
}
// A backend that returns a fixed, unrelated value must be flagged as diverging
// from the reference output.
#[test]
fn reference_diff_detects_backend_divergence() {
// Mock backend that ignores its arguments and always returns 0xDEAD_BEEF.
struct DivergentBackend;
impl HarnessBackend for DivergentBackend {
fn name(&self) -> &str {
"divergent-mock"
}
fn run_with_byte_length(
&self,
_program: &Program,
_inputs: &[Value],
) -> Result<(Vec<Value>, usize), String> {
Ok((vec![Value::U32(0xDEAD_BEEF)], 4))
}
}
// The mock is leaked to obtain a `'static` reference for registration, so it
// is never dropped.
// NOTE(review): this presumably mutates a process-wide registry shared with
// the other tests in this binary — confirm that is acceptable.
register_backend(Box::leak(Box::new(DivergentBackend)));
// NOTE(review): presumably stores the constant 42 at index 0 of the "out"
// buffer — confirm the argument order of `Node::store`.
let program = Program::new(
vec![vyre::ir::BufferDecl::read_write(
"out",
0,
vyre::ir::DataType::U32,
)],
[1, 1, 1],
vec![vyre::ir::Node::store(
"out",
vyre::ir::Expr::u32(0),
vyre::ir::Expr::u32(42),
)],
);
let inputs = vec![Value::U32(0)];
let report = reference_diff(&program, &inputs).expect("reference diff should succeed");
let divergent = report
.backend_results
.iter()
.find(|r| r.backend_name == "divergent-mock")
.expect("divergent-mock must be in results");
assert!(
divergent.diverges_from_reference,
"divergent backend must be flagged"
);
}
// Two backends with different successful outputs must both be flagged by the
// backend-vs-backend pass under strict comparison (`None` tolerance).
#[test]
fn pending_reference_marks_backend_equivalence_divergence() {
let mut results = vec![
BackendDiffResult {
backend_name: "a".to_string(),
result: Ok(vec![Value::U32(1)]),
diverges_from_reference: false,
},
BackendDiffResult {
backend_name: "b".to_string(),
result: Ok(vec![Value::U32(2)]),
diverges_from_reference: false,
},
];
mark_backend_equivalence_divergences(&mut results, None);
assert!(
results.iter().all(|result| result.diverges_from_reference),
"H7 regression: backend/backend mismatch must not be hidden while reference is pending"
);
}
// `left` holds the bit pattern of 1.0_f32 plus one; `right` holds 1.0_f32.
// They must compare equal with a tolerance of 1 and unequal with no tolerance.
#[test]
fn approximate_tolerance_accepts_one_ulp_difference() {
let left = vec![Value::Bytes(
1.0_f32.to_bits().wrapping_add(1).to_le_bytes().to_vec(),
)];
let right = vec![Value::Bytes(1.0_f32.to_le_bytes().to_vec())];
assert!(
values_equal(&left, &right, Some(1)),
"H7 regression: approximate track must use ULP tolerance"
);
assert!(
!values_equal(&left, &right, None),
"strict track must remain byte-exact"
);
}
}