vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
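//! Checker for the `Complement` algebraic law: an op combined with its registered complement companion must sum (wrapping `u32` addition) to a fixed `universe` value.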

use crate::proof::algebra::checker::support::{
    call_binary, call_unary, engine_failure_violation, missing_companion_violation, simple_rng,
    violation,
};
use crate::spec::law::LawViolation;
use crate::{spec::op_registry, DataType};

use super::BOUNDARY_VALUES;

/// Violation reported when `op_id` declares `Complement` but its unary
/// companion `companion` (expected signature `u32 -> u32`) could not be found.
fn missing_unary(op_id: &str, companion: &str) -> LawViolation {
    let (arity, signature) = ("unary", "u32 -> u32");
    missing_companion_violation(op_id, "Complement", companion, arity, signature)
}

/// Violation reported when `op_id` declares `Complement` but its binary
/// companion `companion` (expected signature `u32,u32 -> u32`) could not be found.
fn missing_binary(op_id: &str, companion: &str) -> LawViolation {
    let (arity, signature) = ("binary", "u32,u32 -> u32");
    missing_companion_violation(op_id, "Complement", companion, arity, signature)
}

/// Checks the unary complement law `f(a) + f(g(a)) == universe` (wrapping `u32`
/// addition), where `g` is the CPU implementation registered as `complement_op`.
///
/// Every `a` in `0..=255` is tried first. The boundary values and then
/// `witness_count` pseudo-random `u32` inputs follow.
///
/// Returns the number of cases evaluated together with the first violation found,
/// if any. If no `u32 -> u32` op named `complement_op` exists, the result is
/// `Ok((0, Some(_)))`. Engine failures surface as `Err`.
#[inline]
pub(crate) fn check_exhaustive_u8(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    complement_op: &str,
    universe: u32,
    witness_count: u64,
) -> Result<(u64, Option<LawViolation>), LawViolation> {
    let g = match lookup_unary(complement_op) {
        Some(g) => g,
        None => return Ok((0, Some(missing_unary(op_id, complement_op)))),
    };

    let mut evaluated = 0u64;
    for input in 0..=u32::from(u8::MAX) {
        evaluated += 1;
        if let Some(found) = check_case(op_id, f, g, universe, input)? {
            return Ok((evaluated, Some(found)));
        }
    }
    check_witnesses(op_id, f, g, universe, evaluated, witness_count)
}

/// Checks the binary complement law `f(a, b) + g(a, b) == universe` (wrapping
/// `u32` addition), where `g` is the CPU implementation registered as `complement_op`.
///
/// All 65 536 pairs with `a, b` in `0..=255` are tried first, with `a` as the
/// outer index. The boundary-value pairs and then `witness_count` pseudo-random
/// pairs follow.
///
/// Returns the number of cases evaluated together with the first violation found,
/// if any. If no `u32,u32 -> u32` op named `complement_op` exists, the result is
/// `Ok((0, Some(_)))`. Engine failures surface as `Err`.
#[inline]
pub(crate) fn check_binary_exhaustive_u8(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    complement_op: &str,
    universe: u32,
    witness_count: u64,
) -> Result<(u64, Option<LawViolation>), LawViolation> {
    let g = match lookup_binary(complement_op) {
        Some(g) => g,
        None => return Ok((0, Some(missing_binary(op_id, complement_op)))),
    };

    let bytes = || 0..=u32::from(u8::MAX);
    let mut evaluated = 0u64;
    for (a, b) in bytes().flat_map(|a| bytes().map(move |b| (a, b))) {
        evaluated += 1;
        if let Some(found) = check_binary_case(op_id, f, g, universe, a, b)? {
            return Ok((evaluated, Some(found)));
        }
    }
    check_binary_witnesses(op_id, f, g, universe, evaluated, witness_count)
}

/// Checks the binary complement law `f(a, b) + g(a, b) == universe` using only
/// witnesses: the boundary-value pairs plus `count` deterministic pseudo-random
/// `u32` pairs. There is no exhaustive phase.
///
/// If no `u32,u32 -> u32` op named `complement_op` exists, the result is
/// `Ok((0, Some(_)))`.
#[inline]
pub(crate) fn check_binary_witnessed_u32(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    complement_op: &str,
    universe: u32,
    count: u64,
) -> Result<(u64, Option<LawViolation>), LawViolation> {
    match lookup_binary(complement_op) {
        Some(g) => check_binary_witnesses(op_id, f, g, universe, 0, count),
        None => Ok((0, Some(missing_binary(op_id, complement_op)))),
    }
}

/// Checks the unary complement law `f(a) + f(g(a)) == universe` using only
/// witnesses: the boundary values plus `count` deterministic pseudo-random `u32`
/// inputs. There is no exhaustive phase.
///
/// If no `u32 -> u32` op named `complement_op` exists, the result is
/// `Ok((0, Some(_)))`.
#[inline]
pub(crate) fn check_witnessed_u32(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    complement_op: &str,
    universe: u32,
    count: u64,
) -> Result<(u64, Option<LawViolation>), LawViolation> {
    match lookup_unary(complement_op) {
        Some(g) => check_witnesses(op_id, f, g, universe, 0, count),
        None => Ok((0, Some(missing_unary(op_id, complement_op)))),
    }
}

/// Runs the unary complement check over the boundary values and then over
/// `count` inputs drawn from `simple_rng(op_id, "complement")`.
///
/// `cases` is the running total carried in from any earlier phase. It is
/// incremented once per evaluated input and returned alongside the first
/// violation, or `None` if every input satisfies the law.
fn check_witnesses(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    g: fn(&[u8]) -> Vec<u8>,
    universe: u32,
    mut cases: u64,
    count: u64,
) -> Result<(u64, Option<LawViolation>), LawViolation> {
    // Phase 1: fixed edge-case inputs.
    for a in BOUNDARY_VALUES.iter().copied() {
        cases += 1;
        if let Some(found) = check_case(op_id, f, g, universe, a)? {
            return Ok((cases, Some(found)));
        }
    }

    // Phase 2: pseudo-random inputs. The generator is only built once the
    // boundary phase has passed.
    let mut rng = simple_rng(op_id, "complement");
    for a in (0..count).map(|_| rng.next_u32()) {
        cases += 1;
        if let Some(found) = check_case(op_id, f, g, universe, a)? {
            return Ok((cases, Some(found)));
        }
    }

    Ok((cases, None))
}

/// Runs the binary complement check over every ordered pair of boundary values,
/// then over `count` pairs drawn from `simple_rng(op_id, "complement")`.
///
/// In the random phase each pair draws `a` first and `b` second. `cases` is the
/// running total carried in from any earlier phase and is returned alongside the
/// first violation, or `None` if every pair satisfies the law.
fn check_binary_witnesses(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    g: fn(&[u8]) -> Vec<u8>,
    universe: u32,
    mut cases: u64,
    count: u64,
) -> Result<(u64, Option<LawViolation>), LawViolation> {
    // Phase 1: the full cross product of boundary values, with `a` as the outer index.
    let boundary_pairs = BOUNDARY_VALUES
        .iter()
        .flat_map(|&a| BOUNDARY_VALUES.iter().map(move |&b| (a, b)));
    for (a, b) in boundary_pairs {
        cases += 1;
        if let Some(found) = check_binary_case(op_id, f, g, universe, a, b)? {
            return Ok((cases, Some(found)));
        }
    }

    // Phase 2: pseudo-random pairs.
    let mut rng = simple_rng(op_id, "complement");
    let random_pairs = (0..count).map(|_| {
        let a = rng.next_u32();
        let b = rng.next_u32();
        (a, b)
    });
    for (a, b) in random_pairs {
        cases += 1;
        if let Some(found) = check_binary_case(op_id, f, g, universe, a, b)? {
            return Ok((cases, Some(found)));
        }
    }

    Ok((cases, None))
}

/// Evaluates the unary complement law for a single input `a`.
///
/// Computes `f(a)`, `g(a)` and `f(g(a))`, in that order, and compares the
/// wrapping sum `f(a) + f(g(a))` with `universe`. Returns `Ok(None)` when the law
/// holds and `Ok(Some(_))` describing the mismatch otherwise. Returns `Err(_)` if
/// any CPU call fails.
fn check_case(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    g: fn(&[u8]) -> Vec<u8>,
    universe: u32,
    a: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    // Engine errors are attributed to `op_id`, whichever op actually failed.
    let eval = |op: fn(&[u8]) -> Vec<u8>, input: u32| {
        call_unary(op, input).map_err(|e| engine_failure_violation(op_id, e))
    };

    let direct = eval(f, a)?;
    let complemented = eval(g, a)?;
    let through_complement = eval(f, complemented)?;
    let lhs = direct.wrapping_add(through_complement);

    Ok((lhs != universe).then(|| violation(op_id, "complement", a, 0, 0, lhs, universe)))
}

/// Evaluates the binary complement law for a single pair `(a, b)`.
///
/// Computes `f(a, b)` and then `g(a, b)`, and compares their wrapping sum with
/// `universe`. Returns `Ok(None)` when the law holds and `Ok(Some(_))` describing
/// the mismatch otherwise. Returns `Err(_)` if either CPU call fails.
fn check_binary_case(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    g: fn(&[u8]) -> Vec<u8>,
    universe: u32,
    a: u32,
    b: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    // Engine errors are attributed to `op_id`, whichever op actually failed.
    let eval = |op: fn(&[u8]) -> Vec<u8>| {
        call_binary(op, a, b).map_err(|e| engine_failure_violation(op_id, e))
    };

    let lhs = eval(f)?.wrapping_add(eval(g)?);
    if lhs != universe {
        return Ok(Some(violation(op_id, "complement", a, b, 0, lhs, universe)));
    }
    Ok(None)
}

/// Returns the CPU implementation of the registered op named `op_id`, provided
/// its signature is exactly `u32 -> u32`. Otherwise returns `None`.
fn lookup_unary(op_id: &str) -> Option<fn(&[u8]) -> Vec<u8>> {
    op_registry::all_specs().into_iter().find_map(|spec| {
        let is_match = spec.id == op_id
            && spec.signature.inputs == [DataType::U32]
            && spec.signature.output == DataType::U32;
        is_match.then_some(spec.cpu_fn)
    })
}

/// Returns the CPU implementation of the registered op named `op_id`, provided
/// its signature is exactly `u32, u32 -> u32`. Otherwise returns `None`.
fn lookup_binary(op_id: &str) -> Option<fn(&[u8]) -> Vec<u8>> {
    op_registry::all_specs().into_iter().find_map(|spec| {
        let is_match = spec.id == op_id
            && spec.signature.inputs == [DataType::U32, DataType::U32]
            && spec.signature.output == DataType::U32;
        is_match.then_some(spec.cpu_fn)
    })
}