use crate::pipeline::certify::Violation;
use crate::proof::oracles::{Oracle, VerifyCtx, VerifyResult};
use crate::spec::types::{OpSpec, OracleKind};
/// Property oracle that checks the backend output's byte length.
///
/// The check uses an exact `expected_output_bytes` declared on the op when
/// present. Otherwise it uses the output type's `min_bytes()..=max_bytes()`
/// range. The type is a stateless unit struct.
pub struct BoundedOutputOracle;
/// Instance of [`BoundedOutputOracle`] exported as `REGISTERED`.
// NOTE(review): presumably collected by the oracle registry by this name — confirm.
pub const REGISTERED: BoundedOutputOracle = BoundedOutputOracle;
impl Oracle for BoundedOutputOracle {
    /// Identifier of this oracle: `"bounded_output"`.
    #[inline]
    fn id(&self) -> &'static str {
        "bounded_output"
    }

    /// Classifies this oracle as an [`OracleKind::Property`] check.
    #[inline]
    fn kind(&self) -> OracleKind {
        OracleKind::Property
    }

    /// The oracle applies only when a length bound is known.
    ///
    /// A bound is known when either the op declares an exact
    /// `expected_output_bytes`, or the op's output type reports a
    /// `max_bytes()` bound.
    #[inline]
    fn applicable_to(&self, op: &OpSpec) -> bool {
        op.expected_output_bytes.is_some() || op.signature.output.max_bytes().is_some()
    }

    /// Checks the backend output length against the op's length bound.
    ///
    /// An explicit `expected_output_bytes` takes precedence and requires an
    /// exact match. Otherwise the length must lie in the output type's
    /// inclusive `min_bytes()..=max_bytes()` range.
    ///
    /// # Errors
    ///
    /// Returns a `bounded_output_length` violation when the backend output
    /// length falls outside the bound.
    ///
    /// # Panics
    ///
    /// Panics if called for an op that `applicable_to` rejects: one with no
    /// `expected_output_bytes` and no `max_bytes()` bound on its output type.
    #[inline]
    fn verify(&self, ctx: &VerifyCtx<'_>) -> VerifyResult {
        if let Some(expected) = ctx.op.expected_output_bytes {
            return verify_exact_len(ctx, expected);
        }
        let min = ctx.op.signature.output.min_bytes();
        // `applicable_to` already rejects ops with neither an exact length nor a
        // max bound. Reaching this panic means the caller skipped that check.
        let max = ctx.op.signature.output.max_bytes().unwrap_or_else(|| {
            panic!(
                "Fix: only call BoundedOutputOracle::verify when applicable_to returns true; \
                 op {} has no expected_output_bytes and its output type has no max bound.",
                ctx.op.id
            )
        });
        let actual = ctx.backend_output.len();
        if (min..=max).contains(&actual) {
            return Ok(());
        }
        Err(length_violation(
            ctx,
            format!(
                "Fix: make backend output length for {} stay within {}..={} bytes; got {} bytes.",
                ctx.op.id, min, max, actual
            ),
        ))
    }
}
/// Requires the backend output in `ctx` to be exactly `expected` bytes long.
///
/// Returns `Ok(())` on an exact match. Otherwise returns a
/// `bounded_output_length` violation whose message reports both the expected
/// and the observed byte counts.
fn verify_exact_len(ctx: &VerifyCtx<'_>, expected: usize) -> VerifyResult {
    let actual = ctx.backend_output.len();
    if actual != expected {
        let message = format!(
            "Fix: make backend output length for {} exactly {} bytes; got {} bytes.",
            ctx.op.id, expected, actual
        );
        return Err(length_violation(ctx, message));
    }
    Ok(())
}
/// Builds the `bounded_output_length` [`Violation`] for `ctx`.
///
/// The violation records the op id, the backend id, owned copies of the
/// reference and backend outputs, and the caller-supplied `message`.
fn length_violation(ctx: &VerifyCtx<'_>, message: String) -> Violation {
    let op_id = ctx.op.id.to_string();
    let law = "bounded_output_length".to_string();
    let backend_id = ctx.backend_id.to_string();
    let reference = ctx.reference_output.to_vec();
    let observed = ctx.backend_output.to_vec();
    Violation::new(op_id, law, backend_id, reference, observed, message)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spec::types::{DataType, OpSignature, Strictness};
    use vyre_spec::Category;

    fn cpu(input: &[u8]) -> Vec<u8> {
        input.to_vec()
    }

    fn wgsl() -> String {
        String::new()
    }

    /// Builds a minimal single-input op spec with the given output type and an
    /// optional exact output length.
    fn base_spec(output: DataType, expected_output_bytes: Option<usize>) -> OpSpec {
        let mut spec = OpSpec::builder("test.oracle.bounded_output")
            .signature(OpSignature {
                inputs: vec![DataType::Bytes],
                output,
            })
            .cpu_fn(cpu)
            .wgsl_fn(wgsl)
            .category(Category::A {
                composition_of: vec!["test.oracle.bounded_output"],
            })
            .laws(vec![])
            .strictness(Strictness::Strict)
            .version(1)
            .build()
            .expect("Fix: test spec must build");
        spec.expected_output_bytes = expected_output_bytes;
        spec
    }

    /// Runs the registered oracle on `backend_output` for `spec`.
    ///
    /// Uses the fixed backend id, input, and four-zero-byte reference output
    /// shared by every test below.
    fn run(spec: &OpSpec, backend_output: &[u8]) -> VerifyResult {
        let ctx = VerifyCtx {
            op: spec,
            backend_id: "test",
            input: &[1],
            reference_output: &[0, 0, 0, 0],
            backend_output,
        };
        REGISTERED.verify(&ctx)
    }

    #[test]
    fn accepts_exact_declared_length() {
        let spec = base_spec(DataType::U32, Some(4));
        assert!(run(&spec, &[1, 2, 3, 4]).is_ok());
    }

    #[test]
    fn rejects_wrong_declared_length() {
        let spec = base_spec(DataType::U32, Some(4));
        let err = run(&spec, &[1, 2, 3]).unwrap_err();
        assert_eq!(err.law(), "bounded_output_length");
        assert!(err.message().starts_with("Fix: "));
        assert!(err.message().contains("exactly 4 bytes"));
    }

    #[test]
    fn accepts_output_within_type_bound() {
        // No declared length: the check falls back to the U32 type range (4..=4).
        let spec = base_spec(DataType::U32, None);
        assert!(REGISTERED.applicable_to(&spec));
        assert!(run(&spec, &[9, 8, 7, 6]).is_ok());
    }

    #[test]
    fn rejects_output_larger_than_type_bound() {
        let spec = base_spec(DataType::U32, None);
        let err = run(&spec, &[0, 0, 0, 0, 0]).unwrap_err();
        assert!(err.message().contains("4..=4 bytes"));
    }
}