use crate::pipeline::certify::Violation;
use crate::proof::oracles::{Oracle, VerifyCtx, VerifyResult};
use crate::spec::types::{ComparatorKind, OpSpec, OracleKind, Strictness};
/// Oracle requiring backend output to match the CPU reference byte-for-byte.
///
/// This is a stateless unit struct, so deriving the standard value traits is free.
/// The derives let callers debug-print, copy and default-construct it.
#[derive(Debug, Clone, Copy, Default)]
pub struct BitIdenticalOracle;

/// The `BitIdenticalOracle` value exported under the name `REGISTERED`.
/// NOTE(review): the registry that collects this constant lives outside this file.
pub const REGISTERED: BitIdenticalOracle = BitIdenticalOracle;
/// Oracle implementation: accept only exact byte equality with the CPU reference.
impl Oracle for BitIdenticalOracle {
    /// Identifier string for this oracle.
    #[inline]
    fn id(&self) -> &'static str {
        "bit_identical"
    }

    /// Reports this oracle as a CPU-reference oracle.
    #[inline]
    fn kind(&self) -> OracleKind {
        OracleKind::CpuReference
    }

    /// True only when the op is `Strict` AND uses the `ExactMatch` comparator.
    #[inline]
    fn applicable_to(&self, op: &OpSpec) -> bool {
        matches!(
            (&op.strictness, &op.comparator),
            (Strictness::Strict, ComparatorKind::ExactMatch)
        )
    }

    /// Returns `Ok(())` when backend and reference outputs are equal byte slices.
    ///
    /// # Errors
    ///
    /// On any difference, including a length difference, returns a `Violation`.
    /// It is tagged with the law `bit_identical_cpu_reference` and carries both
    /// outputs, reference first. Its message names the op and the first
    /// differing byte offset.
    #[inline]
    fn verify(&self, ctx: &VerifyCtx<'_>) -> VerifyResult {
        let expected = ctx.reference_output;
        let actual = ctx.backend_output;

        if actual != expected {
            let offset = first_mismatch(actual, expected);
            let op_id = &ctx.op.id;
            let hint = format!(
                "Fix: make backend output byte-identical to the CPU reference for {}; first mismatch at byte {}.",
                op_id, offset
            );
            return Err(Violation::new(
                op_id.to_string(),
                String::from("bit_identical_cpu_reference"),
                ctx.backend_id.to_string(),
                expected.to_vec(),
                actual.to_vec(),
                hint,
            ));
        }

        Ok(())
    }
}
/// Byte offset of the first position where `left` and `right` differ.
///
/// Only the overlapping prefix is compared. If that prefix is identical, the
/// result is the shorter length, i.e. the offset where one slice runs out.
fn first_mismatch(left: &[u8], right: &[u8]) -> usize {
    // `zip` stops at the shorter slice, so counting the equal run gives either
    // the first differing index or the shared-prefix length.
    left.iter()
        .zip(right)
        .take_while(|(l, r)| l == r)
        .count()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spec::types::{DataType, OpSignature};
    use vyre_spec::Category;

    /// CPU function for the test spec: echoes its input unchanged.
    fn cpu(input: &[u8]) -> Vec<u8> {
        input.to_vec()
    }

    /// WGSL source for the test spec: empty.
    fn wgsl() -> String {
        String::new()
    }

    /// Builds a `Strict` test spec.
    /// NOTE(review): the comparator is left at the builder default; presumably
    /// `ExactMatch` given the helper's name — confirm.
    fn strict_exact_spec() -> OpSpec {
        OpSpec::builder("test.oracle.bit_identical")
            .signature(OpSignature {
                inputs: vec![DataType::U32],
                output: DataType::U32,
            })
            .cpu_fn(cpu)
            .wgsl_fn(wgsl)
            .category(Category::A {
                composition_of: vec!["test.oracle.bit_identical"],
            })
            .laws(vec![])
            .strictness(Strictness::Strict)
            .version(1)
            .build()
            .expect("Fix: test spec must build")
    }

    /// Runs the oracle on `reference` vs `backend` with a fixed input and backend id.
    fn verify_pair(reference: &[u8], backend: &[u8]) -> VerifyResult {
        let spec = strict_exact_spec();
        let ctx = VerifyCtx {
            op: &spec,
            backend_id: "test",
            input: &[1, 2, 3, 4],
            reference_output: reference,
            backend_output: backend,
        };
        REGISTERED.verify(&ctx)
    }

    #[test]
    fn accepts_identical_output() {
        assert!(verify_pair(&[7, 8, 9], &[7, 8, 9]).is_ok());
    }

    #[test]
    fn rejects_single_byte_drift() {
        let err = verify_pair(&[7, 8, 9], &[7, 0, 9]).unwrap_err();
        assert_eq!(err.law(), "bit_identical_cpu_reference");
        assert!(err.message().starts_with("Fix: "));
        assert!(err.message().contains("byte 1"));
    }

    #[test]
    fn rejects_truncated_output() {
        let err = verify_pair(&[7, 8, 9], &[7, 8]).unwrap_err();
        assert!(err.message().contains("byte 2"));
    }

    /// A backend that emits extra trailing bytes must also be rejected, with
    /// the mismatch reported at the end of the reference output.
    #[test]
    fn rejects_extended_output() {
        let err = verify_pair(&[7, 8, 9], &[7, 8, 9, 0]).unwrap_err();
        assert_eq!(err.law(), "bit_identical_cpu_reference");
        assert!(err.message().contains("byte 3"));
    }
}