vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Composition theorem witness verification.

use crate::spec::law::{canonical_law_id, AlgebraicLaw};
use crate::spec::types::{DataType, OpSpec};

/// Verify a composition theorem holds for specific ops by testing witnesses.
///
/// `outer` and `inner` supply the CPU reference functions (`cpu_fn`) that are
/// composed according to `theorem.name`. Witness inputs come from
/// `simple_rng(theorem.name, "theorem")`.
///
/// Returns `(tested, violation)`:
/// - `tested` is the number of witnesses evaluated before stopping (equal to
///   `witness_count` when every witness passed, `0` when the request was
///   rejected before any witness ran);
/// - `violation` is `None` on success, or a message describing the first
///   counterexample, a reference-function error, an arity mismatch, a missing
///   required law, a zero `witness_count`, or an unknown theorem name.
#[inline]
pub fn verify_theorem(
    theorem: &CompositionTheorem,
    outer: &OpSpec,
    inner: &OpSpec,
    witness_count: u64,
) -> (u64, Option<String>) {
    use crate::proof::algebra::checker::support::{call_binary, call_unary, simple_rng};

    // Outcome of a single witness: `Ok(None)` passes, `Ok(Some(msg))` is a
    // counterexample, `Err(msg)` is an error raised by a reference function.
    type WitnessOutcome = Result<Option<String>, String>;

    // Runs `check` up to `witness_count` times, stopping at the first
    // counterexample or error and reporting how many witnesses were evaluated.
    fn run_witnesses(
        witness_count: u64,
        mut check: impl FnMut() -> WitnessOutcome,
    ) -> (u64, Option<String>) {
        for i in 0..witness_count {
            match check() {
                Ok(None) => {}
                Ok(Some(msg)) | Err(msg) => return (i + 1, Some(msg)),
            }
        }
        (witness_count, None)
    }

    // Expected (outer, inner) input arity for each known theorem, plus the
    // phrase used in the mismatch message. Unknown theorems are not checked
    // here; they are rejected by the dispatch below.
    fn required_arity(theorem_name: &str) -> Option<(usize, usize, &'static str)> {
        match theorem_name {
            "commutativity_preservation" | "identity_propagation" | "absorbing_short_circuit" => {
                Some((2, 2, "binary outer and binary inner"))
            }
            "bounded_chain" | "involution_chain" => Some((1, 1, "unary outer and unary inner")),
            "idempotent_collapse" => Some((2, 1, "binary outer and unary inner")),
            _ => None,
        }
    }

    if witness_count == 0 {
        return (
            0,
            Some(
                "witness_count must be > 0. Fix: request at least one theorem witness.".to_string(),
            ),
        );
    }

    if let Some((want_outer, want_inner, shape)) = required_arity(theorem.name) {
        let outer_arity = outer.signature.inputs.len();
        let inner_arity = inner.signature.inputs.len();
        if (outer_arity, inner_arity) != (want_outer, want_inner) {
            return (
                0,
                Some(format!(
                    "{} requires {shape}, got outer={outer_arity}, inner={inner_arity}",
                    theorem.name
                )),
            );
        }
    }

    let mut rng = simple_rng(theorem.name, "theorem");
    let outer_fn = outer.cpu_fn;
    let inner_fn = inner.cpu_fn;

    match theorem.name {
        "commutativity_preservation" => run_witnesses(witness_count, || -> WitnessOutcome {
            let a = rng.next_u32();
            let b = rng.next_u32();
            let c = rng.next_u32();
            let d = rng.next_u32();
            // Each inner result is computed once and fed to outer in both orders.
            let ab = call_binary(inner_fn, a, b)?;
            let cd = call_binary(inner_fn, c, d)?;
            let lhs = call_binary(outer_fn, ab, cd)?;
            let rhs = call_binary(outer_fn, cd, ab)?;
            Ok((lhs != rhs).then(|| {
                format!(
                    "commutativity_preservation violated: outer(inner({a},{b}), inner({c},{d}))={lhs}, outer(inner({c},{d}), inner({a},{b}))={rhs}"
                )
            }))
        }),
        "identity_propagation" => {
            let identity = theorem.inner_requires.iter().find_map(|law| match law {
                AlgebraicLaw::Identity { element } => Some(*element),
                _ => None,
            });
            let Some(identity) = identity else {
                return (
                    0,
                    Some("identity_propagation requires Identity law".to_string()),
                );
            };
            run_witnesses(witness_count, || -> WitnessOutcome {
                let a = rng.next_u32();
                let b = rng.next_u32();
                let lhs = call_binary(outer_fn, call_binary(inner_fn, a, identity)?, b)?;
                let rhs = call_binary(outer_fn, a, b)?;
                Ok((lhs != rhs).then(|| {
                    format!(
                        "identity_propagation violated: outer(inner({a},{identity}), {b})={lhs}, outer({a}, {b})={rhs}"
                    )
                }))
            })
        }
        "absorbing_short_circuit" => {
            let absorbing = theorem.inner_requires.iter().find_map(|law| match law {
                AlgebraicLaw::Absorbing { element } => Some(*element),
                _ => None,
            });
            let Some(z) = absorbing else {
                return (
                    0,
                    Some("absorbing_short_circuit requires Absorbing law".to_string()),
                );
            };
            run_witnesses(witness_count, || -> WitnessOutcome {
                let a = rng.next_u32();
                let b = rng.next_u32();
                let lhs = call_binary(outer_fn, call_binary(inner_fn, a, z)?, b)?;
                let rhs = call_binary(outer_fn, z, b)?;
                Ok((lhs != rhs).then(|| {
                    format!(
                        "absorbing_short_circuit violated: outer(inner({a},{z}), {b})={lhs}, outer({z}, {b})={rhs}"
                    )
                }))
            })
        }
        "bounded_chain" => {
            let bounds = theorem.inner_requires.iter().find_map(|law| match law {
                AlgebraicLaw::Bounded { lo, hi } => Some((*lo, *hi)),
                _ => None,
            });
            let Some((lo, hi)) = bounds else {
                return (0, Some("bounded_chain requires Bounded law".to_string()));
            };
            // outer(lo) and outer(hi) do not depend on the witness, so they are
            // computed once; a failure is reported against the first witness.
            let image = call_unary(outer_fn, lo)
                .and_then(|g_lo| call_unary(outer_fn, hi).map(|g_hi| (g_lo, g_hi)));
            let (g_lo, g_hi) = match image {
                Ok(pair) => pair,
                Err(e) => return (1, Some(e)),
            };
            run_witnesses(witness_count, || -> WitnessOutcome {
                let a = rng.next_u32();
                let composed = call_unary(outer_fn, call_unary(inner_fn, a)?)?;
                // `composed`, `g_lo` and `g_hi` are all produced by `outer`, so
                // they are interpreted with outer's output type.
                Ok(
                    outside_bounds(outer.signature.output.clone(), composed, g_lo, g_hi).then(
                        || {
                            format!(
                                "bounded_chain violated: outer(inner({a}))={composed}, not in [{g_lo}, {g_hi}]"
                            )
                        },
                    ),
                )
            })
        }
        "involution_chain" => run_witnesses(witness_count, || -> WitnessOutcome {
            let a = rng.next_u32();
            let ga = call_unary(inner_fn, a)?;
            let ffga = call_unary(outer_fn, call_unary(outer_fn, ga)?)?;
            Ok((ffga != ga).then(|| {
                format!(
                    "involution_chain violated: outer(outer(inner({a})))={ffga}, inner({a})={ga}"
                )
            }))
        }),
        "idempotent_collapse" => run_witnesses(witness_count, || -> WitnessOutcome {
            let a = rng.next_u32();
            let fa = call_unary(inner_fn, a)?;
            // outer(fa, fa) is both the right-hand side and the input to the
            // second application, so it is computed once.
            let once = call_binary(outer_fn, fa, fa)?;
            let twice = call_binary(outer_fn, once, fa)?;
            Ok((twice != once).then(|| {
                format!(
                    "idempotent_collapse violated: outer(outer(inner({a}), inner({a})), inner({a}))={twice}, outer(inner({a}), inner({a}))={once}"
                )
            }))
        }),
        _ => (0, Some(format!("unknown theorem: {}", theorem.name))),
    }
}