use crate::spec::law::{canonical_law_id, AlgebraicLaw};
use crate::spec::types::{DataType, OpSpec};
/// A composition theorem relating an outer op `g` and an inner op `f`.
///
/// Entries of the catalogue are produced by `theorems()`; `verify_theorem`
/// checks one empirically by dispatching on `name`.
#[derive(Debug, Clone)]
pub struct CompositionTheorem {
    /// Identifier of the theorem; `applicable_theorem_instances` and
    /// `verify_theorem` dispatch on this exact string.
    pub name: &'static str,
    /// Human-readable statement of the theorem.
    pub description: &'static str,
    /// Laws the outer op must satisfy for the theorem to apply.
    pub outer_requires: Vec<AlgebraicLaw>,
    /// Laws the inner op must satisfy. For parameterised laws (identity
    /// element, absorbing element, bounds) `verify_theorem` reads the concrete
    /// parameters from this list.
    pub inner_requires: Vec<AlgebraicLaw>,
    /// Laws the composition is meant to guarantee.
    /// NOTE(review): every catalogue entry in this file leaves this empty —
    /// confirm whether any caller reads it.
    pub composition_guarantees: Vec<AlgebraicLaw>,
}
#[inline]
pub fn theorems() -> Vec<CompositionTheorem> {
vec![
CompositionTheorem {
name: "commutativity_preservation",
description:
"If g is commutative, then for any f, g(f(a,b), f(c,d)) = g(f(c,d), f(a,b)). \
The outer operation's commutativity is preserved regardless of the inner op.",
outer_requires: vec![AlgebraicLaw::Commutative],
inner_requires: vec![],
composition_guarantees: vec![],
},
CompositionTheorem {
name: "identity_propagation",
description: "If f has identity e_f and g has identity e_g, then \
g(f(a, e_f), anything) = g(a, anything). The inner identity \
simplifies the composition.",
outer_requires: vec![],
inner_requires: vec![AlgebraicLaw::Identity { element: 0 }],
composition_guarantees: vec![],
},
CompositionTheorem {
name: "absorbing_short_circuit",
description:
"If f has absorbing element z_f, then g(f(a, z_f), b) = g(z_f, b) for any g. \
The absorbing element of the inner op propagates through the outer op.",
outer_requires: vec![],
inner_requires: vec![AlgebraicLaw::Absorbing { element: 0 }],
composition_guarantees: vec![],
},
CompositionTheorem {
name: "bounded_chain",
description: "If f is bounded by [lo_f, hi_f] and g is monotone, then \
g(f(a)) is bounded by [g(lo_f), g(hi_f)]. Bounds compose \
through monotone functions.",
outer_requires: vec![AlgebraicLaw::Monotone],
inner_requires: vec![AlgebraicLaw::Bounded { lo: 0, hi: 0 }],
composition_guarantees: vec![],
},
CompositionTheorem {
name: "involution_chain",
description: "If f is an involution, then f(f(g(a))) = g(a). Applying an \
involution twice around any inner operation is a no-op.",
outer_requires: vec![AlgebraicLaw::Involution],
inner_requires: vec![],
composition_guarantees: vec![],
},
CompositionTheorem {
name: "idempotent_collapse",
description: "If g is idempotent, then g(g(f(a), f(a)), f(a)) = g(f(a), f(a)). \
Repeated application of an idempotent outer op collapses.",
outer_requires: vec![AlgebraicLaw::Idempotent],
inner_requires: vec![],
composition_guarantees: vec![],
},
]
}
#[inline]
pub fn applicable_theorem_instances(
outer_laws: &[AlgebraicLaw],
inner_laws: &[AlgebraicLaw],
) -> Vec<CompositionTheorem> {
let mut applicable = Vec::new();
for theorem in theorems() {
match theorem.name {
"identity_propagation" => {
for law in inner_laws
.iter()
.filter(|law| matches!(law, AlgebraicLaw::Identity { .. }))
{
let mut instance = theorem.clone();
instance.inner_requires = vec![law.clone()];
applicable.push(instance);
}
}
"absorbing_short_circuit" => {
for law in inner_laws
.iter()
.filter(|law| matches!(law, AlgebraicLaw::Absorbing { .. }))
{
let mut instance = theorem.clone();
instance.inner_requires = vec![law.clone()];
applicable.push(instance);
}
}
"bounded_chain" => {
if !outer_laws
.iter()
.any(|law| canonical_law_id(law) == canonical_law_id(&AlgebraicLaw::Monotone))
{
continue;
}
for law in inner_laws
.iter()
.filter(|law| matches!(law, AlgebraicLaw::Bounded { .. }))
{
let mut instance = theorem.clone();
instance.outer_requires = vec![AlgebraicLaw::Monotone];
instance.inner_requires = vec![law.clone()];
applicable.push(instance);
}
}
_ => {
if theorem_requirements_match(&theorem.outer_requires, outer_laws)
&& theorem_requirements_match(&theorem.inner_requires, inner_laws)
{
applicable.push(theorem);
}
}
}
}
applicable
}
/// Empirically checks `theorem` for the composition of `outer` with `inner`.
///
/// Each of the `witness_count` witnesses draws fresh operands from the RNG
/// returned by `simple_rng(theorem.name, "theorem")`. Both sides of the
/// theorem are evaluated through the ops' `cpu_fn`s and then compared.
///
/// Required operand arities per theorem:
/// - `commutativity_preservation`, `identity_propagation`,
///   `absorbing_short_circuit`: binary outer, binary inner;
/// - `bounded_chain`, `involution_chain`: unary outer, unary inner;
/// - `idempotent_collapse`: binary outer, unary inner.
///
/// Law parameters (identity element, absorbing element, bounds) are read from
/// `theorem.inner_requires`. An instance built by `applicable_theorem_instances`
/// carries the inner op's real parameters. A raw `theorems()` template carries
/// only its literal values.
///
/// # Returns
/// - `(witness_count, None)` when every witness satisfies the theorem.
/// - `(k, Some(message))` when witness `k` (1-based) violated the theorem or
///   failed to evaluate. Evaluation errors are passed through unchanged.
/// - `(0, Some(message))` when the check is rejected before any witness runs:
///   zero `witness_count`, arity mismatch, missing law parameter, or an
///   unknown theorem name.
#[inline]
pub fn verify_theorem(
    theorem: &CompositionTheorem,
    outer: &OpSpec,
    inner: &OpSpec,
    witness_count: u64,
) -> (u64, Option<String>) {
    use crate::proof::algebra::checker::support::{call_binary, call_unary, simple_rng};

    /// Describes why `outer`/`inner` have the wrong arity for `theorem_name`,
    /// or returns `None` when the arities fit (or the theorem is unknown).
    fn arity_mismatch(theorem_name: &str, outer: &OpSpec, inner: &OpSpec) -> Option<String> {
        let outer_arity = outer.signature.inputs.len();
        let inner_arity = inner.signature.inputs.len();
        match theorem_name {
            "commutativity_preservation" | "identity_propagation" | "absorbing_short_circuit" => {
                if outer_arity != 2 || inner_arity != 2 {
                    return Some(format!(
                        "{theorem_name} requires binary outer and binary inner, got outer={outer_arity}, inner={inner_arity}"
                    ));
                }
            }
            "bounded_chain" | "involution_chain" => {
                if outer_arity != 1 || inner_arity != 1 {
                    return Some(format!(
                        "{theorem_name} requires unary outer and unary inner, got outer={outer_arity}, inner={inner_arity}"
                    ));
                }
            }
            "idempotent_collapse" => {
                if outer_arity != 2 || inner_arity != 1 {
                    return Some(format!(
                        "{theorem_name} requires binary outer and unary inner, got outer={outer_arity}, inner={inner_arity}"
                    ));
                }
            }
            _ => {}
        }
        None
    }

    /// Runs `check` once per witness and stops at the first failure.
    ///
    /// Returns `(witness_count, None)` on success, or `(i + 1, Some(err))`
    /// where `i` is the 0-based index of the witness whose check failed.
    fn run_witnesses(
        witness_count: u64,
        mut check: impl FnMut() -> Result<(), String>,
    ) -> (u64, Option<String>) {
        for i in 0..witness_count {
            if let Err(err) = check() {
                return (i + 1, Some(err));
            }
        }
        (witness_count, None)
    }

    if witness_count == 0 {
        return (
            0,
            Some(
                "witness_count must be > 0. Fix: request at least one theorem witness.".to_string(),
            ),
        );
    }
    let mut rng = simple_rng(theorem.name, "theorem");
    let outer_fn = outer.cpu_fn;
    let inner_fn = inner.cpu_fn;
    if let Some(err) = arity_mismatch(theorem.name, outer, inner) {
        return (0, Some(err));
    }

    match theorem.name {
        "commutativity_preservation" => run_witnesses(witness_count, || -> Result<(), String> {
            let a = rng.next_u32();
            let b = rng.next_u32();
            let c = rng.next_u32();
            let d = rng.next_u32();
            let lhs = {
                let ab = call_binary(inner_fn, a, b)?;
                let cd = call_binary(inner_fn, c, d)?;
                call_binary(outer_fn, ab, cd)?
            };
            let rhs = {
                let cd = call_binary(inner_fn, c, d)?;
                let ab = call_binary(inner_fn, a, b)?;
                call_binary(outer_fn, cd, ab)?
            };
            if lhs != rhs {
                return Err(format!(
                    "commutativity_preservation violated: outer(inner({a},{b}), inner({c},{d}))={lhs}, outer(inner({c},{d}), inner({a},{b}))={rhs}"
                ));
            }
            Ok(())
        }),
        "identity_propagation" => {
            let identity = theorem.inner_requires.iter().find_map(|law| match law {
                AlgebraicLaw::Identity { element } => Some(*element),
                _ => None,
            });
            let Some(identity) = identity else {
                return (
                    0,
                    Some("identity_propagation requires Identity law".to_string()),
                );
            };
            run_witnesses(witness_count, || -> Result<(), String> {
                let a = rng.next_u32();
                let b = rng.next_u32();
                let lhs = call_binary(outer_fn, call_binary(inner_fn, a, identity)?, b)?;
                let rhs = call_binary(outer_fn, a, b)?;
                if lhs != rhs {
                    return Err(format!(
                        "identity_propagation violated: outer(inner({a},{identity}), {b})={lhs}, outer({a}, {b})={rhs}"
                    ));
                }
                Ok(())
            })
        }
        "absorbing_short_circuit" => {
            let absorbing = theorem.inner_requires.iter().find_map(|law| match law {
                AlgebraicLaw::Absorbing { element } => Some(*element),
                _ => None,
            });
            let Some(absorbing) = absorbing else {
                return (
                    0,
                    Some("absorbing_short_circuit requires Absorbing law".to_string()),
                );
            };
            run_witnesses(witness_count, || -> Result<(), String> {
                let a = rng.next_u32();
                let b = rng.next_u32();
                let lhs = call_binary(outer_fn, call_binary(inner_fn, a, absorbing)?, b)?;
                let rhs = call_binary(outer_fn, absorbing, b)?;
                if lhs != rhs {
                    return Err(format!(
                        "absorbing_short_circuit violated: outer(inner({a},{absorbing}), {b})={lhs}, outer({absorbing}, {b})={rhs}"
                    ));
                }
                Ok(())
            })
        }
        "bounded_chain" => {
            let bounds = theorem.inner_requires.iter().find_map(|law| match law {
                AlgebraicLaw::Bounded { lo, hi } => Some((*lo, *hi)),
                _ => None,
            });
            let Some((lo, hi)) = bounds else {
                return (0, Some("bounded_chain requires Bounded law".to_string()));
            };
            run_witnesses(witness_count, || -> Result<(), String> {
                let a = rng.next_u32();
                let composed = call_unary(outer_fn, call_unary(inner_fn, a)?)?;
                let g_lo = call_unary(outer_fn, lo)?;
                let g_hi = call_unary(outer_fn, hi)?;
                // `composed`, `g_lo` and `g_hi` are all produced by `outer`, so they
                // must be ordered by the OUTER op's output type (e.g. signed I32),
                // not by the inner op's output type.
                if outside_bounds(outer.signature.output.clone(), composed, g_lo, g_hi) {
                    return Err(format!(
                        "bounded_chain violated: outer(inner({a}))={composed}, not in [{g_lo}, {g_hi}]"
                    ));
                }
                Ok(())
            })
        }
        "involution_chain" => run_witnesses(witness_count, || -> Result<(), String> {
            let a = rng.next_u32();
            let ga = call_unary(inner_fn, a)?;
            let ffga = call_unary(outer_fn, call_unary(outer_fn, ga)?)?;
            if ffga != ga {
                return Err(format!(
                    "involution_chain violated: outer(outer(inner({a})))={ffga}, inner({a})={ga}"
                ));
            }
            Ok(())
        }),
        "idempotent_collapse" => run_witnesses(witness_count, || -> Result<(), String> {
            let a = rng.next_u32();
            let fa = call_unary(inner_fn, a)?;
            let lhs = call_binary(outer_fn, call_binary(outer_fn, fa, fa)?, fa)?;
            let rhs = call_binary(outer_fn, fa, fa)?;
            if lhs != rhs {
                return Err(format!(
                    "idempotent_collapse violated: outer(outer(inner({a}), inner({a})), inner({a}))={lhs}, outer(inner({a}), inner({a}))={rhs}"
                ));
            }
            Ok(())
        }),
        _ => (0, Some(format!("unknown theorem: {}", theorem.name))),
    }
}
/// Reports whether `value` falls outside the inclusive range `[lo, hi]`.
///
/// All three arguments are raw 32-bit patterns. For `DataType::I32` they are
/// reinterpreted as signed integers before comparing. Every other data type
/// compares them as unsigned. An inverted range (`lo > hi`) contains nothing,
/// so every value is reported as outside.
fn outside_bounds(output_type: DataType, value: u32, lo: u32, hi: u32) -> bool {
    if matches!(output_type, DataType::I32) {
        // Bit-for-bit reinterpretation, intentionally not a checked conversion.
        let (value, lo, hi) = (value as i32, lo as i32, hi as i32);
        !(lo..=hi).contains(&value)
    } else {
        !(lo..=hi).contains(&value)
    }
}
/// Returns the names of the theorems applicable to the given outer/inner laws.
///
/// Each name appears once, in the order its first instance is produced by
/// `applicable_theorem_instances`, even when a theorem has several instances.
#[inline]
pub fn applicable_theorems(
    outer_laws: &[AlgebraicLaw],
    inner_laws: &[AlgebraicLaw],
) -> Vec<&'static str> {
    applicable_theorem_instances(outer_laws, inner_laws)
        .into_iter()
        .map(|instance| instance.name)
        .fold(Vec::new(), |mut names, name| {
            if !names.contains(&name) {
                names.push(name);
            }
            names
        })
}
/// Returns `true` when every law in `required` has a counterpart in `actual`.
///
/// Laws are compared through `canonical_law_id`. Whether law parameters
/// (e.g. an identity element) take part depends on that function; confirm
/// there. An empty `required` list is always satisfied.
fn theorem_requirements_match(required: &[AlgebraicLaw], actual: &[AlgebraicLaw]) -> bool {
    for needed in required {
        let satisfied = actual
            .iter()
            .any(|law| canonical_law_id(law) == canonical_law_id(needed));
        if !satisfied {
            return false;
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::{outside_bounds, verify_theorem, CompositionTheorem};
    use crate::spec::types::DataType;

    /// A zero witness budget is rejected before any witness runs, with a
    /// message that names `witness_count`.
    #[test]
    fn theorem_verification_rejects_zero_witnesses() {
        let add = crate::spec::primitive::add::spec();
        let probe = CompositionTheorem {
            name: "commutativity_preservation",
            description: "test theorem",
            outer_requires: vec![],
            inner_requires: vec![],
            composition_guarantees: vec![],
        };
        let (checked, failure) = verify_theorem(&probe, &add, &add, 0);
        assert_eq!(checked, 0);
        assert!(matches!(failure.as_deref(), Some(text) if text.contains("witness_count")));
    }

    /// Under `DataType::I32` the bit patterns are compared as signed values.
    #[test]
    fn bounded_chain_comparison_honors_signed_i32_bounds() {
        let signed_min = i32::MIN as u32;
        let minus_one = (-1i32) as u32;
        assert!(!outside_bounds(DataType::I32, minus_one, signed_min, 0));
        assert!(outside_bounds(DataType::I32, 1, signed_min, 0));
    }
}