use super::{pair, primitive, result_u32, unary};
use crate::spec::law::AlgebraicLaw;
/// Integer operands fed to the law checkers in this file: a few small values
/// plus the boundaries around the top bit and the top of the `u32` range.
const SAMPLE: [u32; 8] = [
    0,
    1,
    5,
    u32::MAX >> 1, // 0x7FFF_FFFF: every bit set except the top one
    1 << 31,       // 0x8000_0000: only the top bit set
    0xDEAD_BEEF,   // mixed bit pattern
    u32::MAX - 1,
    u32::MAX,
];
/// Pairs of non-zero operands whose full product is a multiple of 2^32, so a
/// wrapping 32-bit multiplication of either pair yields zero.
const ZERO_PRODUCT_COUNTEREXAMPLES: [(u32, u32); 4] = [
    (2, 1 << 31),
    (1 << 31, 2),
    (1 << 16, 1 << 16),
    (0xFFFF << 16, 1 << 16),
];
/// Runs `spec`'s CPU function on `input` and decodes its output with `result_u32`.
fn call_u32(spec: &crate::spec::types::OpSpec, input: &[u8]) -> u32 {
    let cpu = spec.cpu_fn;
    let output = cpu(input);
    result_u32(&output)
}
/// Encodes `a` with `unary`, passes it through `f`, and decodes the output
/// with `result_u32`.
fn apply_unary(f: fn(&[u8]) -> Vec<u8>, a: u32) -> u32 {
    let encoded = unary(a);
    let output = f(&encoded);
    result_u32(&output)
}
/// Encodes `(a, b)` with `pair`, passes it through `f`, and decodes the output
/// with `result_u32`.
fn apply_binary(f: fn(&[u8]) -> Vec<u8>, a: u32, b: u32) -> u32 {
    let encoded = pair(a, b);
    let output = f(&encoded);
    result_u32(&output)
}
/// Returns the CPU function of the catalog entry whose `id` equals `id`.
///
/// # Panics
/// Panics when `primitive::specs()` contains no entry with that id.
fn lookup_primitive_cpu(id: &str) -> fn(&[u8]) -> Vec<u8> {
    let catalog = primitive::specs();
    match catalog.iter().find(|candidate| candidate.id == id) {
        Some(found) => found.cpu_fn,
        None => panic!(
            "primitive op '{id}' not in catalog; a declared law references an op that does not exist. \
             Fix: register the referenced op or correct the law declaration."
        ),
    }
}
/// True when the spec's signature declares exactly one input.
fn is_unary_spec(spec: &crate::spec::types::OpSpec) -> bool {
    let input_count = spec.signature.inputs.len();
    matches!(input_count, 1)
}
/// True when at least one declared input of the spec has type `DataType::F32`.
fn is_float_spec(spec: &crate::spec::types::OpSpec) -> bool {
    for ty in spec.signature.inputs.iter() {
        if matches!(ty, crate::spec::types::DataType::F32) {
            return true;
        }
    }
    false
}
/// Operands used for specs with an `F32` input, stored as raw `u32` bit patterns.
/// Each trailing comment gives the value of the pattern read as an IEEE-754 single.
const FLOAT_SAMPLE: [u32; 8] = [
    0x0000_0000, // 0.0
    0x3F80_0000, // 1.0
    0x4000_0000, // 2.0
    0x4080_0000, // 4.0
    0x4120_0000, // 10.0
    0xBF80_0000, // -1.0
    0xC000_0000, // -2.0
    0x4F80_0000, // 4294967296.0 (2^32)
];
/// Picks the operand set for `spec`: `FLOAT_SAMPLE` when `is_float_spec` holds,
/// otherwise `SAMPLE`.
fn samples_for(spec: &crate::spec::types::OpSpec) -> &'static [u32] {
    match is_float_spec(spec) {
        true => &FLOAT_SAMPLE,
        false => &SAMPLE,
    }
}
/// Runs `verify_law` for every law declared by every spec in the primitive catalog.
#[test]
fn all_declared_laws_hold_on_cpu_reference() {
    for spec in primitive::specs() {
        spec.laws.iter().for_each(|law| verify_law(&spec, law));
    }
}
/// Routes one declared law to the checker that verifies it against `spec`'s
/// CPU function.
///
/// Laws that name another op (`DeMorgan`, `Complement`, `DistributiveOver`) have
/// that op's CPU function resolved through `lookup_primitive_cpu` first.
///
/// NOTE(review): any variant not listed below lands in the final catch-all arm
/// and is accepted without a check — confirm that this is intentional.
fn verify_law(spec: &crate::spec::types::OpSpec, law: &AlgebraicLaw) {
    match law {
        AlgebraicLaw::Commutative => verify_commutative(spec),
        AlgebraicLaw::Associative => verify_associative(spec),
        AlgebraicLaw::Identity { element } => verify_identity(spec, *element),
        AlgebraicLaw::SelfInverse { result } => verify_self_inverse(spec, *result),
        AlgebraicLaw::Idempotent => verify_idempotent(spec),
        AlgebraicLaw::Absorbing { element } => verify_absorbing(spec, *element),
        AlgebraicLaw::Involution => verify_involution(spec),
        AlgebraicLaw::Bounded { lo, hi } => verify_bounded(spec, *lo, *hi),
        AlgebraicLaw::ZeroProduct { holds } => verify_zero_product(spec, *holds),
        AlgebraicLaw::DeMorgan { inner_op, dual_op } => check_demorgan(
            spec.id,
            spec.cpu_fn,
            inner_op,
            lookup_primitive_cpu(inner_op),
            dual_op,
            lookup_primitive_cpu(dual_op),
        ),
        AlgebraicLaw::Monotone => check_monotone(spec.id, spec.cpu_fn),
        AlgebraicLaw::Complement {
            complement_op,
            universe,
        } => {
            let comp = lookup_primitive_cpu(complement_op);
            // Arity of the complement op decides which binary form of the law applies.
            let comp_is_binary = primitive::specs()
                .iter()
                .find(|s| s.id == *complement_op)
                .map_or(false, |s| s.signature.inputs.len() == 2);
            match (is_unary_spec(spec), comp_is_binary) {
                (true, _) => {
                    check_complement_unary(spec.id, spec.cpu_fn, complement_op, comp, *universe)
                }
                (false, true) => check_complement_binary_both(
                    spec.id,
                    spec.cpu_fn,
                    complement_op,
                    comp,
                    *universe,
                ),
                (false, false) => {
                    check_complement_binary(spec.id, spec.cpu_fn, complement_op, comp, *universe)
                }
            }
        }
        AlgebraicLaw::DistributiveOver { over_op } => {
            check_distributive(spec.id, spec.cpu_fn, over_op, lookup_primitive_cpu(over_op))
        }
        AlgebraicLaw::Custom {
            name, arity, check, ..
        } => check_custom(spec.id, spec.cpu_fn, name, *arity, *check),
        _ => {}
    }
}
/// Asserts `f(a, b) == f(b, a)` for every ordered pair drawn from
/// `samples_for(spec)`.
fn verify_commutative(spec: &crate::spec::types::OpSpec) {
    let domain = samples_for(spec);
    let ordered_pairs = domain
        .iter()
        .flat_map(|&a| domain.iter().map(move |&b| (a, b)));
    for (a, b) in ordered_pairs {
        let forward = call_u32(spec, &pair(a, b));
        let swapped = call_u32(spec, &pair(b, a));
        assert_eq!(
            forward, swapped,
            "{} violates Commutative at ({a}, {b})",
            spec.id
        );
    }
}
/// Asserts `f(f(a, b), c) == f(a, f(b, c))` for every triple drawn from `SAMPLE`.
///
/// NOTE(review): this iterates `SAMPLE` directly, whereas `verify_commutative`
/// goes through `samples_for(spec)` — confirm that is intended for float specs.
fn verify_associative(spec: &crate::spec::types::OpSpec) {
    let op = |x: u32, y: u32| call_u32(spec, &pair(x, y));
    for &a in &SAMPLE {
        for &b in &SAMPLE {
            for &c in &SAMPLE {
                let left_grouped = op(op(a, b), c);
                let right_grouped = op(a, op(b, c));
                assert_eq!(
                    left_grouped, right_grouped,
                    "{} violates Associative at ({a}, {b}, {c})",
                    spec.id
                );
            }
        }
    }
}
/// Asserts that `element` is a two-sided identity: `f(a, element) == a` and
/// `f(element, a) == a` for every `a` in `SAMPLE`.
fn verify_identity(spec: &crate::spec::types::OpSpec, element: u32) {
    let op = |x: u32, y: u32| call_u32(spec, &pair(x, y));
    for a in SAMPLE.iter().copied() {
        // Right-hand position is checked first, then the left-hand position.
        assert_eq!(
            op(a, element),
            a,
            "{} violates Identity({element}) right at {a}",
            spec.id
        );
        assert_eq!(
            op(element, a),
            a,
            "{} violates Identity({element}) left at {a}",
            spec.id
        );
    }
}
/// Asserts `f(a, a) == result` for every `a` in `SAMPLE`.
fn verify_self_inverse(spec: &crate::spec::types::OpSpec, result: u32) {
    SAMPLE.iter().for_each(|&a| {
        let applied_to_self = call_u32(spec, &pair(a, a));
        assert_eq!(
            applied_to_self, result,
            "{} violates SelfInverse({result}) at {a}",
            spec.id
        );
    });
}
/// Checks idempotence over `SAMPLE`.
///
/// * Unary spec: `f(f(a)) == f(a)`.
/// * Otherwise: `f(a, a) == a`.
fn verify_idempotent(spec: &crate::spec::types::OpSpec) {
    match is_unary_spec(spec) {
        true => {
            for &a in &SAMPLE {
                let once = call_u32(spec, &unary(a));
                let twice = call_u32(spec, &unary(once));
                assert_eq!(
                    twice, once,
                    "{} violates Idempotent at {a}: f(f(a))={twice:#010x} != f(a)={once:#010x}",
                    spec.id
                );
            }
        }
        false => {
            for &a in &SAMPLE {
                let applied_to_self = call_u32(spec, &pair(a, a));
                assert_eq!(
                    applied_to_self, a,
                    "{} violates Idempotent at {a}",
                    spec.id
                );
            }
        }
    }
}
/// Asserts that `element` absorbs on both sides: `f(a, element) == element` and
/// `f(element, a) == element` for every `a` in `SAMPLE`.
fn verify_absorbing(spec: &crate::spec::types::OpSpec, element: u32) {
    let op = |x: u32, y: u32| call_u32(spec, &pair(x, y));
    for a in SAMPLE.iter().copied() {
        // Right-hand position is checked first, then the left-hand position.
        assert_eq!(
            op(a, element),
            element,
            "{} violates Absorbing({element}) right at {a}",
            spec.id
        );
        assert_eq!(
            op(element, a),
            element,
            "{} violates Absorbing({element}) left at {a}",
            spec.id
        );
    }
}
/// Asserts `f(f(a)) == a` for every `a` in `SAMPLE`.
fn verify_involution(spec: &crate::spec::types::OpSpec) {
    let apply = |x: u32| call_u32(spec, &unary(x));
    for a in SAMPLE.iter().copied() {
        let round_trip = apply(apply(a));
        assert_eq!(round_trip, a, "{} violates Involution at {a}", spec.id);
    }
}
/// Asserts that every output over `SAMPLE` lies in the inclusive range `[lo, hi]`.
///
/// A spec with exactly two inputs is evaluated on every pair from `SAMPLE`;
/// any other spec is evaluated as a unary function on each sample.
fn verify_bounded(spec: &crate::spec::types::OpSpec, lo: u32, hi: u32) {
    let allowed = lo..=hi;
    match spec.signature.inputs.len() {
        2 => {
            for &a in &SAMPLE {
                for &b in &SAMPLE {
                    let val = call_u32(spec, &pair(a, b));
                    assert!(
                        allowed.contains(&val),
                        "{} violates Bounded([{lo}, {hi}]) at ({a}, {b}): got {val}",
                        spec.id
                    );
                }
            }
        }
        _ => {
            for &a in &SAMPLE {
                let val = call_u32(spec, &unary(a));
                assert!(
                    allowed.contains(&val),
                    "{} violates Bounded([{lo}, {hi}]) at {a}: got {val}",
                    spec.id
                );
            }
        }
    }
}
/// Verifies a `ZeroProduct` declaration.
///
/// * `holds == true`: no pair from `SAMPLE` x `SAMPLE` or from
///   `ZERO_PRODUCT_COUNTEREXAMPLES` may produce zero unless one operand is zero.
/// * `holds == false`: at least one pair of non-zero operands (the known
///   counterexamples are tried first, then the sample pairs) must produce zero.
fn verify_zero_product(spec: &crate::spec::types::OpSpec, holds: bool) {
    let sample_pairs = || {
        SAMPLE
            .iter()
            .flat_map(|&a| SAMPLE.iter().map(move |&b| (a, b)))
    };

    if !holds {
        let mut found_counterexample = false;
        let candidates = ZERO_PRODUCT_COUNTEREXAMPLES
            .iter()
            .copied()
            .chain(sample_pairs());
        for (a, b) in candidates {
            if a != 0 && b != 0 && call_u32(spec, &pair(a, b)) == 0 {
                found_counterexample = true;
                break;
            }
        }
        assert!(
            found_counterexample,
            "{} declares ZeroProduct(false), but no non-zero sampled pair produced zero",
            spec.id
        );
        return;
    }

    let every_pair = sample_pairs().chain(ZERO_PRODUCT_COUNTEREXAMPLES.iter().copied());
    for (a, b) in every_pair {
        assert_zero_product_holds_for_pair(spec, a, b);
    }
}
/// Asserts that `f(a, b)` is not zero when both `a` and `b` are non-zero.
fn assert_zero_product_holds_for_pair(spec: &crate::spec::types::OpSpec, a: u32, b: u32) {
    let product = call_u32(spec, &pair(a, b));
    let zero_from_nonzero_inputs = product == 0 && a != 0 && b != 0;
    assert!(
        !zero_from_nonzero_inputs,
        "{} violates ZeroProduct(true): f({a}, {b}) = 0 with both inputs non-zero",
        spec.id
    );
}
/// Asserts the De Morgan relation `f(inner(a, b)) == dual(f(a), f(b))` for every
/// pair drawn from `SAMPLE`.
///
/// `id`, `inner_name` and `dual_name` are only used in the failure message.
fn check_demorgan(
    id: &str,
    f_self: fn(&[u8]) -> Vec<u8>,
    inner_name: &str,
    f_inner: fn(&[u8]) -> Vec<u8>,
    dual_name: &str,
    f_dual: fn(&[u8]) -> Vec<u8>,
) {
    let operand_pairs = SAMPLE
        .iter()
        .flat_map(|&a| SAMPLE.iter().map(move |&b| (a, b)));
    for (a, b) in operand_pairs {
        let lhs = apply_unary(f_self, apply_binary(f_inner, a, b));
        let rhs = apply_binary(f_dual, apply_unary(f_self, a), apply_unary(f_self, b));
        assert_eq!(
            lhs, rhs,
            "{id} violates DeMorgan({inner_name}, {dual_name}) at ({a:#010x}, {b:#010x}): \
             f(inner(a,b))={lhs:#010x} vs dual(f(a),f(b))={rhs:#010x}"
        );
    }
}
/// Asserts that `f` never decreases between neighbouring values of the sorted
/// `SAMPLE` set, i.e. `a <= b` implies `f(a) <= f(b)` for each adjacent pair.
fn check_monotone(id: &str, f: fn(&[u8]) -> Vec<u8>) {
    let mut ordered = SAMPLE.to_vec();
    ordered.sort_unstable();
    // Pair each element with its successor.
    for (&a, &b) in ordered.iter().zip(ordered.iter().skip(1)) {
        let fa = apply_unary(f, a);
        let fb = apply_unary(f, b);
        assert!(
            fa <= fb,
            "{id} violates Monotone at ({a:#010x} <= {b:#010x}): \
             f(a)={fa:#010x} > f(b)={fb:#010x}"
        );
    }
}
/// Asserts `f(a, comp(a)) == universe` for every `a` in `SAMPLE`, where `f` is
/// binary and `f_comp` is unary.
///
/// `id` and `comp_name` are only used in the failure message.
fn check_complement_binary(
    id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    comp_name: &str,
    f_comp: fn(&[u8]) -> Vec<u8>,
    universe: u32,
) {
    for a in SAMPLE.iter().copied() {
        let combined = apply_binary(f, a, apply_unary(f_comp, a));
        assert_eq!(
            combined, universe,
            "{id} violates Complement({comp_name}, universe={universe}) at {a:#010x}: \
             f(a, {comp_name}(a)) = {combined:#010x}"
        );
    }
}
/// Asserts `f(a, b) + comp(a, b) == universe` for every pair drawn from `SAMPLE`,
/// where both `f` and `f_comp` are binary.
///
/// The sum is taken modulo 2^32 (`wrapping_add`), matching
/// `check_complement_unary`. A plain `+` would abort with an arithmetic-overflow
/// panic in overflow-checked builds instead of reporting the law violation.
///
/// `id` and `comp_name` are only used in the failure message.
///
/// # Panics
/// Panics with a `Complement` violation message on the first pair whose wrapped
/// sum differs from `universe`.
fn check_complement_binary_both(
    id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    comp_name: &str,
    f_comp: fn(&[u8]) -> Vec<u8>,
    universe: u32,
) {
    for a in SAMPLE {
        for b in SAMPLE {
            let f_ab = apply_binary(f, a, b);
            let comp_ab = apply_binary(f_comp, a, b);
            // Computed once and reused for both the comparison and the message.
            let combined = f_ab.wrapping_add(comp_ab);
            assert_eq!(
                combined, universe,
                "{id} violates Complement({comp_name}, universe={universe}) at ({a:#010x}, {b:#010x}): \
                 f(a,b) + {comp_name}(a,b) = {f_ab} + {comp_ab} = {combined}"
            );
        }
    }
}
/// Asserts `f(a) + f(comp(a)) == universe` (wrapping addition) for every `a` in
/// `SAMPLE`, where both `f` and `f_comp` are unary.
///
/// `id` and `comp_name` are only used in the failure message.
fn check_complement_unary(
    id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    comp_name: &str,
    f_comp: fn(&[u8]) -> Vec<u8>,
    universe: u32,
) {
    for a in SAMPLE.iter().copied() {
        let sum = apply_unary(f, a).wrapping_add(apply_unary(f, apply_unary(f_comp, a)));
        assert_eq!(
            sum, universe,
            "{id} violates Complement({comp_name}, universe={universe}) at {a:#010x}: \
             f(a) + f({comp_name}(a)) = {sum}"
        );
    }
}
/// Asserts left-distributivity of `f` over `f_over`:
/// `f(a, over(b, c)) == over(f(a, b), f(a, c))` for every triple drawn from `SAMPLE`.
///
/// `id` and `over_name` are only used in the failure message.
fn check_distributive(
    id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    over_name: &str,
    f_over: fn(&[u8]) -> Vec<u8>,
) {
    for &a in &SAMPLE {
        for &b in &SAMPLE {
            for &c in &SAMPLE {
                let lhs = apply_binary(f, a, apply_binary(f_over, b, c));
                let rhs = apply_binary(f_over, apply_binary(f, a, b), apply_binary(f, a, c));
                assert_eq!(
                    lhs, rhs,
                    "{id} violates DistributiveOver({over_name}) at \
                     ({a:#010x}, {b:#010x}, {c:#010x}): \
                     f(a,over(b,c))={lhs:#010x} vs over(f(a,b),f(a,c))={rhs:#010x}"
                );
            }
        }
    }
}
/// Runs a custom law predicate against every witness from
/// `custom_witness_args(arity)` and asserts that it returns `true` for each one.
///
/// `id` and `name` are only used in the failure message, which also reports the
/// index and contents of the failing witness.
fn check_custom(
    id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    name: &str,
    arity: usize,
    check: crate::spec::law::LawCheckFn,
) {
    let witnesses = custom_witness_args(arity);
    for (i, args) in witnesses.iter().enumerate() {
        assert!(
            check(f, args),
            "{id} violates Custom({name}) at witness {i}: args={args:?}"
        );
    }
}
/// Builds the witness argument lists handed to custom law predicates.
///
/// Two families are produced, in this order:
/// 1. one "diagonal" witness per sample: the value repeated `max(arity, 1)` times;
/// 2. one "mirrored" witness per sample: `[a, b, a + b, a + b, ...]` where `b` is
///    the sample at the mirrored position of `SAMPLE` and `a + b` (wrapping) pads
///    the list up to `arity`. These always hold at least two values.
fn custom_witness_args(arity: usize) -> Vec<Vec<u32>> {
    let diagonal = SAMPLE.iter().map(|&v| vec![v; arity.max(1)]);
    let mirrored = SAMPLE.iter().zip(SAMPLE.iter().rev()).map(|(&a, &b)| {
        let padding = arity.saturating_sub(2);
        let mut row = vec![a, b];
        row.extend(std::iter::repeat(a.wrapping_add(b)).take(padding));
        row
    });
    diagonal.chain(mirrored).collect()
}
/// Test double: bitwise AND of the two little-endian `u32` words at byte
/// offsets 0 and 4 of `input`, returned as four little-endian bytes.
fn fake_and(input: &[u8]) -> Vec<u8> {
    let word = |at: usize| {
        u32::from_le_bytes([input[at], input[at + 1], input[at + 2], input[at + 3]])
    };
    let lhs = word(0);
    let rhs = word(4);
    Vec::from((lhs & rhs).to_le_bytes())
}
/// Test double: bitwise OR of the two little-endian `u32` words at byte
/// offsets 0 and 4 of `input`, returned as four little-endian bytes.
fn fake_or(input: &[u8]) -> Vec<u8> {
    let word = |at: usize| {
        u32::from_le_bytes([input[at], input[at + 1], input[at + 2], input[at + 3]])
    };
    let lhs = word(0);
    let rhs = word(4);
    Vec::from((lhs | rhs).to_le_bytes())
}
/// Test double: wrapping sum of the two little-endian `u32` words at byte
/// offsets 0 and 4 of `input`, returned as four little-endian bytes.
fn fake_add(input: &[u8]) -> Vec<u8> {
    let word = |at: usize| {
        u32::from_le_bytes([input[at], input[at + 1], input[at + 2], input[at + 3]])
    };
    let lhs = word(0);
    let rhs = word(4);
    let total = lhs.wrapping_add(rhs);
    Vec::from(total.to_le_bytes())
}
/// Test double: wrapping product of the two little-endian `u32` words at byte
/// offsets 0 and 4 of `input`, returned as four little-endian bytes.
fn fake_mul(input: &[u8]) -> Vec<u8> {
    let word = |at: usize| {
        u32::from_le_bytes([input[at], input[at + 1], input[at + 2], input[at + 3]])
    };
    let lhs = word(0);
    let rhs = word(4);
    let product = lhs.wrapping_mul(rhs);
    Vec::from(product.to_le_bytes())
}
/// Test double: bitwise complement of the little-endian `u32` in the first
/// four bytes of `input`, returned as four little-endian bytes.
fn fake_not(input: &[u8]) -> Vec<u8> {
    let operand = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
    let flipped = !operand;
    Vec::from(flipped.to_le_bytes())
}
/// Test double: returns a copy of the first four bytes of `input`.
fn fake_identity_unary(input: &[u8]) -> Vec<u8> {
    let operand = &input[..4];
    operand.iter().copied().collect()
}
/// Test double: number of set bits in the little-endian `u32` held in the first
/// four bytes of `input`, returned as a little-endian `u32`.
fn fake_popcount(input: &[u8]) -> Vec<u8> {
    let operand = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
    let set_bits: u32 = operand.count_ones();
    Vec::from(set_bits.to_le_bytes())
}
/// Deliberately wrong `not`: returns the first four bytes of `input` unchanged.
fn fake_broken_not_returns_input(input: &[u8]) -> Vec<u8> {
    let mut echoed = Vec::with_capacity(4);
    echoed.extend_from_slice(&input[..4]);
    echoed
}
/// Deliberately non-monotone function: maps the little-endian `u32` operand `a`
/// to `u32::MAX - a`, so larger inputs produce smaller outputs.
fn fake_broken_monotone_inverts(input: &[u8]) -> Vec<u8> {
    let operand = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
    let mirrored = u32::MAX - operand;
    Vec::from(mirrored.to_le_bytes())
}
/// Deliberately wrong `popcount`: returns only the lowest bit of the
/// little-endian `u32` operand, as a little-endian `u32`.
fn fake_broken_popcount_low_bit(input: &[u8]) -> Vec<u8> {
    let operand = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
    let low_bit = operand & 1;
    Vec::from(low_bit.to_le_bytes())
}
/// Deliberately wrong `mul`: returns the wrapping sum of the two little-endian
/// `u32` words at byte offsets 0 and 4 of `input`.
fn fake_broken_mul_is_add(input: &[u8]) -> Vec<u8> {
    let word = |at: usize| {
        u32::from_le_bytes([input[at], input[at + 1], input[at + 2], input[at + 3]])
    };
    let lhs = word(0);
    let rhs = word(4);
    Vec::from(lhs.wrapping_add(rhs).to_le_bytes())
}
/// Custom-law predicate stub that accepts every witness; both arguments are ignored.
fn fake_custom_always_true(_f: fn(&[u8]) -> Vec<u8>, _args: &[u32]) -> bool {
true
}
/// Custom-law predicate stub that rejects every witness; both arguments are ignored.
fn fake_custom_always_false(_f: fn(&[u8]) -> Vec<u8>, _args: &[u32]) -> bool {
false
}
/// Custom-law predicate stub that rejects exactly those witnesses whose first
/// two arguments are both `u32::MAX`, and accepts everything else. `f` is unused.
fn fake_custom_subtle_violation(f: fn(&[u8]) -> Vec<u8>, args: &[u32]) -> bool {
    let _ = f;
    !matches!(args, [u32::MAX, u32::MAX, ..])
}
/// Negative control: `check_demorgan` must panic for a `not` that echoes its input.
#[test]
fn demorgan_checker_catches_violation() {
    let attempt = || {
        check_demorgan(
            "fake.broken_not",
            fake_broken_not_returns_input,
            "and",
            fake_and,
            "or",
            fake_or,
        )
    };
    let caught = std::panic::catch_unwind(attempt).is_err();
    assert!(
        caught,
        "DeMorgan checker failed to catch a `not` impl that returns its input unchanged"
    );
}
/// Positive control: a real bitwise `not` satisfies De Morgan with `and` / `or`.
#[test]
fn demorgan_checker_accepts_valid_not() {
    check_demorgan(
        "fake.not",
        fake_not, // op under test
        "and",
        fake_and, // inner op
        "or",
        fake_or, // dual op
    );
}
/// Negative control: `check_monotone` must panic for an order-inverting function.
#[test]
fn monotone_checker_catches_violation() {
    let attempt = || check_monotone("fake.broken_monotone", fake_broken_monotone_inverts);
    let caught = std::panic::catch_unwind(attempt).is_err();
    assert!(
        caught,
        "Monotone checker failed to catch an order-inverting impl"
    );
}
/// Positive control: the identity function is monotone.
#[test]
fn monotone_checker_accepts_identity() {
    let op_under_test: fn(&[u8]) -> Vec<u8> = fake_identity_unary;
    check_monotone("fake.identity", op_under_test);
}
/// Negative control: `check_complement_unary` must panic for a "popcount" that
/// only reports the low bit.
#[test]
fn complement_unary_checker_catches_violation() {
    let attempt = || {
        check_complement_unary(
            "fake.broken_popcount",
            fake_broken_popcount_low_bit,
            "primitive.bitwise.not",
            fake_not,
            32,
        )
    };
    let caught = std::panic::catch_unwind(attempt).is_err();
    assert!(
        caught,
        "Complement checker (unary) failed to catch a popcount impl returning only the low bit"
    );
}
/// Positive control: `popcount(a) + popcount(!a) == 32` for a real popcount.
#[test]
fn complement_unary_checker_accepts_valid_popcount() {
    let universe = 32;
    check_complement_unary(
        "fake.popcount",
        fake_popcount, // op under test
        "primitive.bitwise.not",
        fake_not, // complement op
        universe,
    );
}
/// Negative control: `a & !a` is never `0xFFFF_FFFF`, so `check_complement_binary`
/// must panic when AND is declared with that universe.
#[test]
fn complement_binary_checker_catches_violation() {
    let attempt = || {
        check_complement_binary(
            "fake.and_with_wrong_universe",
            fake_and,
            "primitive.bitwise.not",
            fake_not,
            0xFFFF_FFFF,
        )
    };
    let caught = std::panic::catch_unwind(attempt).is_err();
    assert!(
        caught,
        "Complement checker (binary) failed to catch wrong universe for AND"
    );
}
/// Positive control: `a | !a == u32::MAX` for a real bitwise OR.
#[test]
fn complement_binary_checker_accepts_valid_or() {
    let universe = u32::MAX;
    check_complement_binary(
        "fake.or",
        fake_or, // op under test
        "primitive.bitwise.not",
        fake_not, // complement op
        universe,
    );
}
/// Negative control: addition does not distribute over addition, so
/// `check_distributive` must panic when `add` is passed off as `mul`.
#[test]
fn distributive_checker_catches_violation() {
    let attempt = || {
        check_distributive(
            "fake.broken_mul",
            fake_broken_mul_is_add,
            "primitive.math.add",
            fake_add,
        )
    };
    let caught = std::panic::catch_unwind(attempt).is_err();
    assert!(
        caught,
        "DistributiveOver checker failed to catch an `add` posing as `mul`"
    );
}
/// Positive control: wrapping multiplication distributes over wrapping addition.
#[test]
fn distributive_checker_accepts_valid_mul() {
    check_distributive(
        "fake.mul",
        fake_mul, // op under test
        "primitive.math.add",
        fake_add, // op distributed over
    );
}
/// Negative control: `check_custom` must panic when the predicate rejects everything.
#[test]
fn custom_checker_catches_always_false() {
    let attempt = || {
        check_custom(
            "fake.any",
            fake_identity_unary,
            "always-false",
            1,
            fake_custom_always_false,
        )
    };
    let caught = std::panic::catch_unwind(attempt).is_err();
    assert!(
        caught,
        "Custom checker failed to reject a predicate that always returns false"
    );
}
/// Positive control: `check_custom` passes when the predicate accepts everything.
#[test]
fn custom_checker_accepts_always_true() {
    let arity = 1;
    check_custom(
        "fake.any",
        fake_identity_unary, // op under test (ignored by the predicate)
        "always-true",
        arity,
        fake_custom_always_true,
    );
}
/// Negative control: the witness set must include `(u32::MAX, u32::MAX)`, so a
/// predicate failing only there has to make `check_custom` panic.
#[test]
fn custom_checker_catches_subtle_max_max_violation() {
    let attempt = || {
        check_custom(
            "fake.subtle",
            fake_identity_unary,
            "subtle-max-max-fail",
            2,
            fake_custom_subtle_violation,
        )
    };
    let caught = std::panic::catch_unwind(attempt).is_err();
    assert!(
        caught,
        "Custom checker failed to catch a subtle violation at (u32::MAX, u32::MAX) — \
         the checker's witness coverage is missing the top-of-range boundary"
    );
}