use crate::proof::algebra::inference::{recommend, static_name, InferenceReport, InferredLaw};
use crate::spec::law::{AlgebraicLaw, LawViolation};
use crate::spec::types::DataType;
use crate::OpSpec;
/// Infers algebraic laws that relate op `op_id` to the other ops in `companions`.
///
/// `cpu_fn` is the reference CPU implementation of `op_id`; it is sampled on
/// deterministic pseudo-random `u32` operands. Which laws are probed depends on
/// the signature recorded for `op_id` in `companions`:
///
/// - `(u32, u32) -> u32`: distributivity, complement, lattice absorption and
///   inverse against every other binary companion, then trichotomy over pairs
///   of binary companions;
/// - `(u32) -> u32`: De Morgan dualities over ordered pairs of distinct binary
///   companions;
/// - `op_id` absent from `companions`, or any other signature: nothing is probed.
///
/// Every law that held on all evaluable samples is returned in `proven` together
/// with its name and recommendation; reference-engine failures are reported in
/// `findings`. `skipped` is always empty.
pub fn infer_cross_op_laws(
    op_id: &str,
    cpu_fn: fn(&[u8]) -> Vec<u8>,
    companions: &[OpSpec],
) -> InferenceReport {
    let mut proven = Vec::new();
    let mut findings = Vec::new();
    let binary_companions: Vec<&OpSpec> = companions
        .iter()
        .filter(|spec| is_binary_u32(spec))
        .collect();
    let current = companions.iter().find(|spec| spec.id == op_id);
    // The law variants store `&'static str` ids, so the current op's id is taken
    // from its spec rather than from the caller-provided `op_id`.
    let current_static_id = current.map(|spec| spec.id);
    // An op missing from `companions` has no known signature or static id, so it
    // is treated as neither binary nor unary and nothing is probed.
    let current_is_binary = current.is_some_and(is_binary_u32);
    let current_is_unary = current.is_some_and(is_unary_u32);
    // A signature cannot have both one and two inputs, so the two paths are
    // mutually exclusive; within the binary path the original probe order
    // (pairwise laws, then trichotomy) is kept.
    if current_is_binary {
        infer_binary_cross_laws(
            op_id,
            cpu_fn,
            &binary_companions,
            &mut proven,
            &mut findings,
        );
        infer_trichotomy(
            cpu_fn,
            &binary_companions,
            current_static_id,
            &mut proven,
            &mut findings,
        );
    } else if current_is_unary {
        infer_unary_cross_laws(
            cpu_fn,
            &binary_companions,
            current_static_id,
            &mut proven,
            &mut findings,
        );
    }
    let proven_inferred: Vec<InferredLaw> = proven
        .into_iter()
        .map(|law| {
            let name = static_name(law.name());
            let recommendation = recommend(&law);
            InferredLaw {
                law,
                name,
                recommendation,
            }
        })
        .collect();
    let proven_count = proven_inferred.len();
    InferenceReport {
        op_id: op_id.to_string(),
        proven: proven_inferred,
        skipped: Vec::new(),
        findings,
        summary: format!("{op_id}: inferred {proven_count} cross-op laws"),
    }
}
/// Probes pairwise laws between the current binary op and each *other* binary
/// companion, pushing every law that holds onto `proven`.
///
/// Per companion, in order: distributivity of the current op over it,
/// complement with universe `u32::MAX`, lattice absorption, and inverse.
/// Returns without probing anything when `op_id` is not among
/// `binary_companions`, because the `'static` id stored in the law variants
/// is borrowed from the matching spec.
fn infer_binary_cross_laws(
    op_id: &str,
    cpu_fn: fn(&[u8]) -> Vec<u8>,
    binary_companions: &[&OpSpec],
    proven: &mut Vec<AlgebraicLaw>,
    findings: &mut Vec<LawViolation>,
) {
    let static_id = match binary_companions.iter().find(|spec| spec.id == op_id) {
        Some(spec) => spec.id,
        None => return,
    };
    let others = binary_companions
        .iter()
        .filter(|spec| spec.id != static_id);
    for other in others {
        let (other_fn, other_id) = (other.cpu_fn, other.id);

        let distributes = test_distributive_over(cpu_fn, other_fn, static_id, other_id, findings);
        if distributes {
            proven.push(AlgebraicLaw::DistributiveOver { over_op: other_id });
        }

        let complements =
            test_binary_complement(cpu_fn, other_fn, static_id, other_id, u32::MAX, findings);
        if complements {
            proven.push(AlgebraicLaw::Complement {
                complement_op: other_id,
                universe: u32::MAX,
            });
        }

        let absorbs = test_lattice_absorption(cpu_fn, other_fn, static_id, other_id, findings);
        if absorbs {
            proven.push(AlgebraicLaw::LatticeAbsorption { dual_op: other_id });
        }

        let inverts = test_inverse_of(cpu_fn, other_fn, static_id, other_id, findings);
        if inverts {
            proven.push(AlgebraicLaw::InverseOf { op: other_id });
        }
    }
}
/// Probes De Morgan dualities for the current unary op: for every ordered pair
/// of binary companions with different ids, checks
/// `unary(inner(a, b)) == dual(unary(a), unary(b))` and records the law when it
/// holds. Does nothing when `current_static_id` is `None`.
fn infer_unary_cross_laws(
    cpu_fn: fn(&[u8]) -> Vec<u8>,
    binary_companions: &[&OpSpec],
    current_static_id: Option<&'static str>,
    proven: &mut Vec<AlgebraicLaw>,
    findings: &mut Vec<LawViolation>,
) {
    let unary_id = match current_static_id {
        Some(id) => id,
        None => return,
    };
    for inner in binary_companions {
        let duals = binary_companions
            .iter()
            .filter(|dual| dual.id != inner.id);
        for dual in duals {
            let holds = test_de_morgan(
                cpu_fn,
                inner.cpu_fn,
                dual.cpu_fn,
                unary_id,
                inner.id,
                dual.id,
                findings,
            );
            if holds {
                proven.push(AlgebraicLaw::DeMorgan {
                    inner_op: inner.id,
                    dual_op: dual.id,
                });
            }
        }
    }
}
/// Probes the trichotomy law for the current binary op together with every
/// ordered pair of *other* binary companions.
///
/// For each ordered pair `(first, second)` of distinct companions, the current
/// op is tried in each of the three roles — `(current, first, second)`,
/// `(first, current, second)` and `(first, second, current)` read as
/// `(less, equal, greater)` — and an `AlgebraicLaw::Trichotomy` is recorded for
/// every assignment that `test_trichotomy` accepts. Does nothing when
/// `current_static_id` is `None`.
///
/// The current op is excluded from the companion pool, matching
/// `infer_binary_cross_laws`. Without that, it could be paired with itself and
/// fill two roles of one triple, which is a degenerate assignment and must not
/// be recorded as a law.
fn infer_trichotomy(
    cpu_fn: fn(&[u8]) -> Vec<u8>,
    binary_companions: &[&OpSpec],
    current_static_id: Option<&'static str>,
    proven: &mut Vec<AlgebraicLaw>,
    findings: &mut Vec<LawViolation>,
) {
    // A role is filled by a reference function plus the id reported for it.
    type Role = (fn(&[u8]) -> Vec<u8>, &'static str);

    let Some(op_id) = current_static_id else {
        return;
    };
    let others: Vec<&OpSpec> = binary_companions
        .iter()
        .copied()
        .filter(|spec| spec.id != op_id)
        .collect();
    let current: Role = (cpu_fn, op_id);
    for (i, first_spec) in others.iter().enumerate() {
        for (j, second_spec) in others.iter().enumerate() {
            if i == j {
                continue;
            }
            let first: Role = (first_spec.cpu_fn, first_spec.id);
            let second: Role = (second_spec.cpu_fn, second_spec.id);
            // Same three role assignments, in the same order, as before.
            for (less, equal, greater) in [
                (current, first, second),
                (first, current, second),
                (first, second, current),
            ] {
                if test_trichotomy(
                    less.0, equal.0, greater.0, less.1, equal.1, greater.1, findings,
                ) {
                    proven.push(AlgebraicLaw::Trichotomy {
                        less_op: less.1,
                        equal_op: equal.1,
                        greater_op: greater.1,
                    });
                }
            }
        }
    }
}
/// Returns `true` when `spec` has exactly the signature `(u32, u32) -> u32`.
fn is_binary_u32(spec: &OpSpec) -> bool {
    let sig = &spec.signature;
    sig.output == DataType::U32 && sig.inputs == [DataType::U32, DataType::U32]
}
/// Returns `true` when `spec` has exactly the signature `(u32) -> u32`.
fn is_unary_u32(spec: &OpSpec) -> bool {
    let sig = &spec.signature;
    let single_u32_input =
        sig.inputs.len() == 1 && sig.inputs.iter().all(|input| *input == DataType::U32);
    single_u32_input && sig.output == DataType::U32
}
/// Failure to decode a `u32` result from a reference CPU function's output.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ApplyError {
    /// The op returned fewer than the 4 bytes needed for a little-endian `u32`.
    TruncatedOutput {
        /// Id of the op whose output was too short.
        op_id: &'static str,
        /// Number of bytes the op actually returned.
        actual_len: usize,
    },
}
/// Runs a binary reference op on `(a, b)` and decodes its `u32` result.
///
/// The operands are passed as 8 bytes: `a` followed by `b`, each little-endian.
/// Fails with `ApplyError::TruncatedOutput` (tagged with `op_id`) when the op
/// returns fewer than 4 bytes.
fn apply_binary(
    f: fn(&[u8]) -> Vec<u8>,
    a: u32,
    b: u32,
    op_id: &'static str,
) -> Result<u32, ApplyError> {
    let mut packed = [0u8; 8];
    let (first_half, second_half) = packed.split_at_mut(4);
    first_half.copy_from_slice(&a.to_le_bytes());
    second_half.copy_from_slice(&b.to_le_bytes());
    read_u32_output(f(&packed), op_id)
}
/// Runs a unary reference op on `a` (passed as 4 little-endian bytes) and
/// decodes its `u32` result, failing with `ApplyError::TruncatedOutput` when
/// the op returns fewer than 4 bytes.
fn apply_unary(f: fn(&[u8]) -> Vec<u8>, a: u32, op_id: &'static str) -> Result<u32, ApplyError> {
    let operand = a.to_le_bytes();
    let raw_output = f(&operand);
    read_u32_output(raw_output, op_id)
}
/// Decodes the first 4 bytes of `out` as a little-endian `u32`.
///
/// Bytes beyond the fourth are ignored. Outputs shorter than 4 bytes yield
/// `ApplyError::TruncatedOutput` carrying `op_id` and the actual length.
fn read_u32_output(out: Vec<u8>, op_id: &'static str) -> Result<u32, ApplyError> {
    match out.as_slice() {
        [b0, b1, b2, b3, ..] => Ok(u32::from_le_bytes([*b0, *b1, *b2, *b3])),
        too_short => Err(ApplyError::TruncatedOutput {
            op_id,
            actual_len: too_short.len(),
        }),
    }
}
/// Converts an engine evaluation error into an `"engine-failure"` `LawViolation`.
///
/// The operand and result fields (`a`, `b`, `c`, `lhs`, `rhs`) are zeroed
/// because no law was evaluated. The message names the offending op and the
/// output length it produced.
fn truncation_finding(err: ApplyError) -> LawViolation {
    // `ApplyError` has a single variant, so this destructuring is irrefutable.
    // Adding a variant turns it into a compile error here, which is intended.
    let ApplyError::TruncatedOutput { op_id, actual_len } = err;
    let message = format!(
        "engine evaluation failed: op `{op_id}` output too short (expected 4 bytes, got {actual_len}). Fix: ensure the reference engine produces valid 4-byte u32 outputs."
    );
    LawViolation {
        law: "engine-failure".to_string(),
        op_id: op_id.to_string(),
        a: 0,
        b: 0,
        c: 0,
        lhs: 0,
        rhs: 0,
        message,
    }
}
fn test_distributive_over(
outer: fn(&[u8]) -> Vec<u8>,
inner: fn(&[u8]) -> Vec<u8>,
outer_id: &'static str,
inner_id: &'static str,
findings: &mut Vec<LawViolation>,
) -> bool {
let mut state = 0x5F37_59DF;
let mut tested = 0u32;
for _ in 0..256 {
let a = next_u32(&mut state);
let b = next_u32(&mut state);
let c = next_u32(&mut state);
let inner_bc = match apply_binary(inner, b, c, inner_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let lhs = match apply_binary(outer, a, inner_bc, outer_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let outer_ab = match apply_binary(outer, a, b, outer_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let outer_ac = match apply_binary(outer, a, c, outer_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let rhs = match apply_binary(inner, outer_ab, outer_ac, inner_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let lhs_right = match apply_binary(outer, inner_bc, a, outer_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let outer_ba = match apply_binary(outer, b, a, outer_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let outer_ca = match apply_binary(outer, c, a, outer_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let rhs_right = match apply_binary(inner, outer_ba, outer_ca, inner_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
tested += 1;
if lhs != rhs || lhs_right != rhs_right {
return false;
}
}
tested > 0
}
fn test_de_morgan(
unary: fn(&[u8]) -> Vec<u8>,
inner: fn(&[u8]) -> Vec<u8>,
dual: fn(&[u8]) -> Vec<u8>,
unary_id: &'static str,
inner_id: &'static str,
dual_id: &'static str,
findings: &mut Vec<LawViolation>,
) -> bool {
let mut state = 0xD1B5_4A32;
let mut tested = 0u32;
for _ in 0..256 {
let a = next_u32(&mut state);
let b = next_u32(&mut state);
let inner_ab = match apply_binary(inner, a, b, inner_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let lhs = match apply_unary(unary, inner_ab, unary_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let ua = match apply_unary(unary, a, unary_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let ub = match apply_unary(unary, b, unary_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let rhs = match apply_binary(dual, ua, ub, dual_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
tested += 1;
if lhs != rhs {
return false;
}
}
tested > 0
}
fn test_binary_complement(
f: fn(&[u8]) -> Vec<u8>,
g: fn(&[u8]) -> Vec<u8>,
f_id: &'static str,
g_id: &'static str,
universe: u32,
findings: &mut Vec<LawViolation>,
) -> bool {
let mut state = 0xC0FF_EE12;
let mut tested = 0u32;
for _ in 0..256 {
let a = next_u32(&mut state);
let b = next_u32(&mut state);
let fa = match apply_binary(f, a, b, f_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let ga = match apply_binary(g, a, b, g_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
tested += 1;
if fa.wrapping_add(ga) != universe {
return false;
}
}
tested > 0
}
fn test_lattice_absorption(
f: fn(&[u8]) -> Vec<u8>,
dual: fn(&[u8]) -> Vec<u8>,
f_id: &'static str,
dual_id: &'static str,
findings: &mut Vec<LawViolation>,
) -> bool {
let mut state = 0xA850_5AFE;
let mut tested = 0u32;
for _ in 0..256 {
let a = next_u32(&mut state);
let b = next_u32(&mut state);
let dual_ab = match apply_binary(dual, a, b, dual_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let f_absorbs = match apply_binary(f, a, dual_ab, f_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let f_ab = match apply_binary(f, a, b, f_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let dual_absorbs = match apply_binary(dual, a, f_ab, dual_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
tested += 1;
if f_absorbs != a || dual_absorbs != a {
return false;
}
}
tested > 0
}
fn test_inverse_of(
f: fn(&[u8]) -> Vec<u8>,
inverse: fn(&[u8]) -> Vec<u8>,
f_id: &'static str,
inverse_id: &'static str,
findings: &mut Vec<LawViolation>,
) -> bool {
let mut state = 0x1EAF_BEEF;
let mut tested = 0u32;
for _ in 0..256 {
let a = next_u32(&mut state);
let b = next_u32(&mut state);
let inv_ab = match apply_binary(inverse, a, b, inverse_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let result = match apply_binary(f, inv_ab, b, f_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
tested += 1;
if result != a {
return false;
}
}
tested > 0
}
fn test_trichotomy(
less: fn(&[u8]) -> Vec<u8>,
equal: fn(&[u8]) -> Vec<u8>,
greater: fn(&[u8]) -> Vec<u8>,
less_id: &'static str,
equal_id: &'static str,
greater_id: &'static str,
findings: &mut Vec<LawViolation>,
) -> bool {
let mut state = 0x7A1C_4010;
let mut tested = 0u32;
for _ in 0..256 {
let a = next_u32(&mut state);
let b = next_u32(&mut state);
let less_ab = match apply_binary(less, a, b, less_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let equal_ab = match apply_binary(equal, a, b, equal_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
let greater_ab = match apply_binary(greater, a, b, greater_id) {
Ok(v) => v,
Err(e) => {
findings.push(truncation_finding(e));
continue;
}
};
tested += 1;
if less_ab.wrapping_add(equal_ab).wrapping_add(greater_ab) != 1 {
return false;
}
}
tested > 0
}
/// Advances a 32-bit xorshift generator (shift triple 13, 17, 5) in place and
/// returns the new state.
///
/// A zero state is a fixed point and stays `0` forever, so seeds must be
/// nonzero. Every caller in this file seeds with a nonzero constant.
fn next_u32(state: &mut u32) -> u32 {
    let mut x = *state;
    x ^= x << 13;
    x ^= x >> 17;
    x ^= x << 5;
    *state = x;
    x
}
#[cfg(test)]
mod tests {
    // The glob already brings the module's private helpers (`apply_binary`,
    // `apply_unary`, `ApplyError`, the `test_*` probes, `next_u32`) into scope.
    use super::*;

    /// Decodes the little-endian `(a, b)` pair that `apply_binary` packs into 8 bytes.
    fn binary_args(input: &[u8]) -> (u32, u32) {
        let a = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
        let b = u32::from_le_bytes([input[4], input[5], input[6], input[7]]);
        (a, b)
    }

    /// Decodes the little-endian operand that `apply_unary` passes as 4 bytes.
    fn unary_arg(input: &[u8]) -> u32 {
        u32::from_le_bytes([input[0], input[1], input[2], input[3]])
    }

    fn encode(value: u32) -> Vec<u8> {
        value.to_le_bytes().to_vec()
    }

    fn short_binary(_: &[u8]) -> Vec<u8> {
        vec![1, 2]
    }

    fn short_unary(_: &[u8]) -> Vec<u8> {
        vec![1]
    }

    fn good_binary(input: &[u8]) -> Vec<u8> {
        let (a, b) = binary_args(input);
        encode(a.wrapping_add(b))
    }

    fn sub_binary(input: &[u8]) -> Vec<u8> {
        let (a, b) = binary_args(input);
        encode(a.wrapping_sub(b))
    }

    fn and_binary(input: &[u8]) -> Vec<u8> {
        let (a, b) = binary_args(input);
        encode(a & b)
    }

    fn or_binary(input: &[u8]) -> Vec<u8> {
        let (a, b) = binary_args(input);
        encode(a | b)
    }

    fn first_binary(input: &[u8]) -> Vec<u8> {
        let (a, _) = binary_args(input);
        encode(a)
    }

    fn not_first_binary(input: &[u8]) -> Vec<u8> {
        let (a, _) = binary_args(input);
        encode(!a)
    }

    fn lt_binary(input: &[u8]) -> Vec<u8> {
        let (a, b) = binary_args(input);
        encode(u32::from(a < b))
    }

    fn eq_binary(input: &[u8]) -> Vec<u8> {
        let (a, b) = binary_args(input);
        encode(u32::from(a == b))
    }

    fn gt_binary(input: &[u8]) -> Vec<u8> {
        let (a, b) = binary_args(input);
        encode(u32::from(a > b))
    }

    fn not_unary(input: &[u8]) -> Vec<u8> {
        encode(!unary_arg(input))
    }

    /// Returns the operand followed by two junk bytes that must be ignored.
    fn padded_unary(input: &[u8]) -> Vec<u8> {
        let mut out = encode(unary_arg(input));
        out.extend_from_slice(&[0xAA, 0xBB]);
        out
    }

    #[test]
    fn apply_binary_errors_on_truncated_output() {
        let err = apply_binary(short_binary, 1, 2, "test.short").unwrap_err();
        assert_eq!(
            err,
            ApplyError::TruncatedOutput {
                op_id: "test.short",
                actual_len: 2
            }
        );
    }

    #[test]
    fn apply_unary_errors_on_truncated_output() {
        let err = apply_unary(short_unary, 1, "test.short").unwrap_err();
        assert_eq!(
            err,
            ApplyError::TruncatedOutput {
                op_id: "test.short",
                actual_len: 1
            }
        );
    }

    #[test]
    fn apply_binary_ok_on_valid_output() {
        assert_eq!(apply_binary(good_binary, 1, 2, "test.good"), Ok(3));
    }

    #[test]
    fn apply_unary_ignores_trailing_output_bytes() {
        assert_eq!(
            apply_unary(padded_unary, 0xDEAD_BEEF, "test.padded"),
            Ok(0xDEAD_BEEF)
        );
    }

    #[test]
    fn next_u32_matches_xorshift32_reference_values() {
        let mut state = 1;
        assert_eq!(next_u32(&mut state), 270_369);
        assert_eq!(next_u32(&mut state), 67_634_689);
        assert_eq!(state, 67_634_689);
    }

    #[test]
    fn test_distributive_over_handles_truncation() {
        let mut findings = Vec::new();
        assert!(!test_distributive_over(
            good_binary,
            short_binary,
            "test.good",
            "test.short",
            &mut findings,
        ));
        assert!(!findings.is_empty());
    }

    #[test]
    fn and_distributes_over_or() {
        let mut findings = Vec::new();
        assert!(test_distributive_over(
            and_binary,
            or_binary,
            "test.and",
            "test.or",
            &mut findings,
        ));
        assert!(findings.is_empty());
    }

    #[test]
    fn add_does_not_distribute_over_and() {
        let mut findings = Vec::new();
        assert!(!test_distributive_over(
            good_binary,
            and_binary,
            "test.add",
            "test.and",
            &mut findings,
        ));
        assert!(findings.is_empty());
    }

    #[test]
    fn not_satisfies_de_morgan_over_and_or() {
        let mut findings = Vec::new();
        assert!(test_de_morgan(
            not_unary,
            and_binary,
            or_binary,
            "test.not",
            "test.and",
            "test.or",
            &mut findings,
        ));
        assert!(findings.is_empty());
    }

    #[test]
    fn first_and_its_negation_are_complements() {
        let mut findings = Vec::new();
        assert!(test_binary_complement(
            first_binary,
            not_first_binary,
            "test.first",
            "test.not_first",
            u32::MAX,
            &mut findings,
        ));
        assert!(findings.is_empty());
    }

    #[test]
    fn and_or_satisfy_lattice_absorption() {
        let mut findings = Vec::new();
        assert!(test_lattice_absorption(
            and_binary,
            or_binary,
            "test.and",
            "test.or",
            &mut findings,
        ));
        assert!(findings.is_empty());
    }

    #[test]
    fn add_is_inverse_of_sub() {
        let mut findings = Vec::new();
        assert!(test_inverse_of(
            good_binary,
            sub_binary,
            "test.add",
            "test.sub",
            &mut findings,
        ));
        assert!(findings.is_empty());
    }

    #[test]
    fn u32_comparisons_satisfy_trichotomy() {
        let mut findings = Vec::new();
        assert!(test_trichotomy(
            lt_binary,
            eq_binary,
            gt_binary,
            "test.lt",
            "test.eq",
            "test.gt",
            &mut findings,
        ));
        assert!(findings.is_empty());
    }
}