vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Checker for lattice absorption.

use crate::proof::algebra::checker::support::{
    call_binary, engine_failure_violation, missing_companion_violation, simple_rng, violation,
};
use crate::spec::law::LawViolation;
use crate::{spec::op_registry, DataType};

use super::BOUNDARY_VALUES;

/// Check the absorption law `f(a, g(a,b)) = a` for every `(a, b)` pair with
/// both operands in `0..256`, then continue with the u32 witness checks.
///
/// `dual_op` names the companion operation `g`; it is resolved through
/// `lookup_binary`. If it cannot be resolved, no case is evaluated and the
/// result is `(0, Some(..))` carrying a missing-companion violation.
///
/// Returns the number of cases evaluated and the first violation found, if
/// any. Evaluation stops at the first violation.
#[inline]
pub(crate) fn check_exhaustive_u8(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    dual_op: &str,
    witness_count: u64,
) -> (u64, Option<LawViolation>) {
    let dual = match lookup_binary(dual_op) {
        Some(resolved) => resolved,
        None => {
            let missing = missing_companion_violation(
                op_id,
                "LatticeAbsorption",
                dual_op,
                "dual",
                "u32,u32 -> u32",
            );
            return (0, Some(missing));
        }
    };

    // Walk the full 256 x 256 grid, `a` varying slowest.
    let mut evaluated: u64 = 0;
    let grid = (0u32..256).flat_map(|a| (0u32..256).map(move |b| (a, b)));
    for (a, b) in grid {
        evaluated += 1;
        if let Some(found) = check_case(op_id, f, dual, a, b) {
            return (evaluated, Some(found));
        }
    }

    // The witness phase keeps counting from where the grid left off.
    check_witnesses(op_id, f, dual, evaluated, witness_count)
}

/// Check the absorption law `f(a, g(a,b)) = a` on u32 witnesses only.
///
/// `dual_op` names the companion operation `g`; it is resolved through
/// `lookup_binary`. If it cannot be resolved, the result is `(0, Some(..))`
/// carrying a missing-companion violation. Otherwise the work is delegated to
/// `check_witnesses` with the case counter starting at zero and `count`
/// generated witness pairs.
#[inline]
pub(crate) fn check_witnessed_u32(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    dual_op: &str,
    count: u64,
) -> (u64, Option<LawViolation>) {
    match lookup_binary(dual_op) {
        Some(dual) => check_witnesses(op_id, f, dual, 0, count),
        None => {
            let missing = missing_companion_violation(
                op_id,
                "LatticeAbsorption",
                dual_op,
                "dual",
                "u32,u32 -> u32",
            );
            (0, Some(missing))
        }
    }
}

/// Run the absorption check on witness pairs and report how many cases ran.
///
/// Two phases, in order:
/// 1. every ordered pair drawn from `BOUNDARY_VALUES`;
/// 2. `count` pairs taken from the generator returned by
///    `simple_rng(op_id, "lattice-absorption")`.
///
/// `cases` is the number of cases already evaluated by the caller; the
/// returned counter continues from it. Evaluation stops at the first
/// violation, which is returned alongside the counter.
fn check_witnesses(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    dual: fn(&[u8]) -> Vec<u8>,
    cases: u64,
    count: u64,
) -> (u64, Option<LawViolation>) {
    let mut evaluated = cases;

    // Phase 1: boundary values crossed with themselves, `a` varying slowest.
    let boundary_pairs = BOUNDARY_VALUES
        .iter()
        .flat_map(|&a| BOUNDARY_VALUES.iter().map(move |&b| (a, b)));
    for (a, b) in boundary_pairs {
        evaluated += 1;
        if let Some(found) = check_case(op_id, f, dual, a, b) {
            return (evaluated, Some(found));
        }
    }

    // Phase 2: generated pairs; `a` is drawn before `b` on every round.
    let mut rng = simple_rng(op_id, "lattice-absorption");
    for _round in 0..count {
        let a = rng.next_u32();
        let b = rng.next_u32();
        evaluated += 1;
        if let Some(found) = check_case(op_id, f, dual, a, b) {
            return (evaluated, Some(found));
        }
    }

    (evaluated, None)
}

/// Check both absorption identities for a single `(a, b)` pair.
///
/// The identities are tested in this order:
/// 1. `f(a, g(a, b)) = a`
/// 2. `g(a, f(a, b)) = a` (the dual form)
///
/// where `g` is `dual`. Returns `None` when both hold. Otherwise returns the
/// first problem encountered: either an engine-failure violation built from a
/// `call_binary` error, or a `"lattice-absorption"` violation recording the
/// observed value against the expected `a`.
fn check_case(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    dual: fn(&[u8]) -> Vec<u8>,
    a: u32,
    b: u32,
) -> Option<LawViolation> {
    // Evaluate one binary op, turning an engine error into a violation.
    let apply = |op: fn(&[u8]) -> Vec<u8>, lhs: u32, rhs: u32| -> Result<u32, LawViolation> {
        call_binary(op, lhs, rhs).map_err(|failure| engine_failure_violation(op_id, failure))
    };
    // An absorption identity holds exactly when the outer result equals `a`.
    let expect_a = |observed: u32| -> Result<(), LawViolation> {
        if observed == a {
            Ok(())
        } else {
            Err(violation(
                op_id,
                "lattice-absorption",
                a,
                b,
                0,
                observed,
                a,
            ))
        }
    };

    let outcome = (|| -> Result<(), LawViolation> {
        // Absorption: f(a, g(a,b)) = a
        let inner = apply(dual, a, b)?;
        expect_a(apply(f, a, inner)?)?;

        // Dual absorption: g(a, f(a,b)) = a
        let inner = apply(f, a, b)?;
        expect_a(apply(dual, a, inner)?)
    })();

    outcome.err()
}

/// Find the CPU implementation of the registered operation named `op_id`,
/// provided its signature is exactly `(U32, U32) -> U32`.
///
/// Returns the `cpu_fn` of the first matching spec from
/// `op_registry::all_specs()`, or `None` when no spec has that id together
/// with that signature.
fn lookup_binary(op_id: &str) -> Option<fn(&[u8]) -> Vec<u8>> {
    op_registry::all_specs().into_iter().find_map(|spec| {
        let is_match = spec.id == op_id
            && spec.signature.inputs == [DataType::U32, DataType::U32]
            && spec.signature.output == DataType::U32;
        is_match.then_some(spec.cpu_fn)
    })
}