use crate::proof::algebra::checker::support::{
call_binary, engine_failure_violation, missing_companion_violation, simple_rng, violation,
};
use crate::spec::law::LawViolation;
use crate::{spec::op_registry, DataType};
use super::BOUNDARY_VALUES;
/// Checks the trichotomy law over every operand pair in `0..256 x 0..256`,
/// then continues with the boundary and random witnesses.
///
/// `f` is the implementation registered for `op_id`; the other two
/// comparison ops are resolved by name through `ops_for`.
///
/// Returns `Ok((cases, violation))`, where `cases` is the number of operand
/// pairs evaluated so far and `violation` is the first failing case, if any.
/// When the comparison ops cannot be resolved, the resolution failure is
/// reported as `Ok((0, Some(..)))`. An engine failure while evaluating an op
/// is propagated as `Err`.
#[inline]
pub(crate) fn check_exhaustive_u8(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    less_op: &str,
    equal_op: &str,
    greater_op: &str,
    witness_count: u64,
) -> Result<(u64, Option<LawViolation>), LawViolation> {
    let ops = match ops_for(op_id, f, less_op, equal_op, greater_op) {
        Err(unresolved) => return Ok((0, Some(unresolved))),
        Ok(resolved) => resolved,
    };

    let mut checked: u64 = 0;
    for lhs in 0u32..=255 {
        for rhs in 0u32..=255 {
            checked += 1;
            let outcome = check_case(op_id, ops, lhs, rhs)?;
            if outcome.is_some() {
                return Ok((checked, outcome));
            }
        }
    }

    // The exhaustive sweep passed; keep counting from where it stopped.
    check_witnesses(op_id, ops, checked, witness_count)
}
/// Checks the trichotomy law on the boundary-value pairs plus `count`
/// pseudo-random operand pairs, without an exhaustive sweep.
///
/// `f` is the implementation registered for `op_id`; the other two
/// comparison ops are resolved by name through `ops_for`.
///
/// Returns `Ok((cases, violation))` as produced by `check_witnesses`, or
/// `Ok((0, Some(..)))` when the comparison ops cannot be resolved. An engine
/// failure while evaluating an op is propagated as `Err`.
#[inline]
pub(crate) fn check_witnessed_u32(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    less_op: &str,
    equal_op: &str,
    greater_op: &str,
    count: u64,
) -> Result<(u64, Option<LawViolation>), LawViolation> {
    match ops_for(op_id, f, less_op, equal_op, greater_op) {
        Ok(resolved) => check_witnesses(op_id, resolved, 0, count),
        Err(unresolved) => Ok((0, Some(unresolved))),
    }
}
/// Runs the trichotomy check on every ordered pair drawn from
/// `BOUNDARY_VALUES`, then on `count` operand pairs produced by the RNG
/// returned from `simple_rng(op_id, "trichotomy")`.
///
/// `cases` is the number of cases already evaluated by the caller; it is
/// incremented once per pair checked here. Returns the running total
/// together with the first violation found, or `None` if every pair passed.
/// Engine failures from `check_case` are propagated as `Err`.
fn check_witnesses(
    op_id: &str,
    ops: CompareOps,
    mut cases: u64,
    count: u64,
) -> Result<(u64, Option<LawViolation>), LawViolation> {
    for &lhs in BOUNDARY_VALUES {
        for &rhs in BOUNDARY_VALUES {
            cases += 1;
            let outcome = check_case(op_id, ops, lhs, rhs)?;
            if outcome.is_some() {
                return Ok((cases, outcome));
            }
        }
    }

    let mut rng = simple_rng(op_id, "trichotomy");
    let mut remaining = count;
    while remaining > 0 {
        remaining -= 1;
        // Draw order matters: the left operand is generated first.
        let lhs = rng.next_u32();
        let rhs = rng.next_u32();
        cases += 1;
        let outcome = check_case(op_id, ops, lhs, rhs)?;
        if outcome.is_some() {
            return Ok((cases, outcome));
        }
    }

    Ok((cases, None))
}
/// Evaluates the three comparison ops on `(a, b)` and verifies that their
/// results add up to exactly 1.
///
/// Returns `Ok(None)` when the law holds for this pair, and
/// `Ok(Some(violation))` carrying the observed sum (as `lhs`) and the
/// expected value 1 (as `rhs`) when it does not. A failure reported by
/// `call_binary` is converted with `engine_failure_violation` and returned
/// as `Err`.
fn check_case(
    op_id: &str,
    ops: CompareOps,
    a: u32,
    b: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    let run = |op: fn(&[u8]) -> Vec<u8>| {
        call_binary(op, a, b).map_err(|e| engine_failure_violation(op_id, e))
    };
    let less = run(ops.less)?;
    let equal = run(ops.equal)?;
    let greater = run(ops.greater)?;

    // Saturating rather than wrapping: with wrapping arithmetic, results such
    // as (u32::MAX, 1, 1) wrap around to 1 and a real violation would be
    // accepted. A saturating sum of unsigned values equals 1 only when the
    // true sum is 1. NOTE(review): assumes the ops yield unsigned results.
    let sum = less.saturating_add(equal).saturating_add(greater);
    if sum == 1 {
        Ok(None)
    } else {
        Ok(Some(violation(op_id, "trichotomy", a, b, 0, sum, 1)))
    }
}
/// The three comparison implementations evaluated together by the
/// trichotomy check. Each field is a plain function pointer taking a byte
/// slice and returning a byte vector, so the bundle is cheap to copy.
#[derive(Copy, Clone)]
struct CompareOps {
    /// Implementation resolved for the "less" op.
    less: fn(&[u8]) -> Vec<u8>,
    /// Implementation resolved for the "equal" op.
    equal: fn(&[u8]) -> Vec<u8>,
    /// Implementation resolved for the "greater" op.
    greater: fn(&[u8]) -> Vec<u8>,
}
/// Resolves the three comparison implementations needed to check trichotomy
/// for `op_id`.
///
/// `op_id` must be one of `less_op`, `equal_op`, `greater_op`; otherwise a
/// `LawViolation` describing the misdeclaration is returned. For each role
/// whose name equals `op_id`, the supplied `f` is used directly; every other
/// role is looked up with `lookup_binary`, and a missing companion is
/// reported through `missing_companion_violation`. Roles are resolved in the
/// order less, equal, greater, so the first missing companion is the one
/// reported.
fn ops_for(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    less_op: &str,
    equal_op: &str,
    greater_op: &str,
) -> Result<CompareOps, LawViolation> {
    let participates = [less_op, equal_op, greater_op]
        .iter()
        .any(|&candidate| candidate == op_id);
    if !participates {
        let message = format!(
            "Trichotomy declared on `{op_id}` but it is not one of `{less_op}`, \
             `{equal_op}`, `{greater_op}`. Fix: declare Trichotomy only on a \
             participating comparison op."
        );
        return Err(LawViolation {
            law: "Trichotomy".to_string(),
            op_id: op_id.to_string(),
            a: 0,
            b: 0,
            c: 0,
            lhs: 0,
            rhs: 0,
            message,
        });
    }

    // One resolution rule shared by all three roles: use `f` when the role is
    // the op under test, otherwise fetch the companion from the registry.
    let resolve = |role_op: &str| {
        if op_id == role_op {
            Ok(f)
        } else {
            lookup_binary(role_op).ok_or_else(|| {
                missing_companion_violation(
                    op_id,
                    "Trichotomy",
                    role_op,
                    "comparison",
                    "u32,u32 -> u32",
                )
            })
        }
    };

    let less = resolve(less_op)?;
    let equal = resolve(equal_op)?;
    let greater = resolve(greater_op)?;

    Ok(CompareOps {
        less,
        equal,
        greater,
    })
}
/// Finds the CPU implementation of the registered op named `op_id`.
///
/// Only a spec whose signature takes exactly two `DataType::U32` inputs and
/// produces a `DataType::U32` output is accepted; the first such spec in
/// `op_registry::all_specs()` wins. Returns `None` when no spec matches.
fn lookup_binary(op_id: &str) -> Option<fn(&[u8]) -> Vec<u8>> {
    op_registry::all_specs().into_iter().find_map(|spec| {
        let is_u32_binary = spec.id == op_id
            && spec.signature.inputs == [DataType::U32, DataType::U32]
            && spec.signature.output == DataType::U32;
        if is_u32_binary {
            Some(spec.cpu_fn)
        } else {
            None
        }
    })
}