vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Checker for comparison trichotomy.

use crate::proof::algebra::checker::support::{
    call_binary, engine_failure_violation, missing_companion_violation, simple_rng, violation,
};
use crate::spec::law::LawViolation;
use crate::{spec::op_registry, DataType};

use super::BOUNDARY_VALUES;

/// Checks the trichotomy law `lt(a,b) + eq(a,b) + gt(a,b) = 1` on every
/// `(a, b)` pair with both operands in `0..256`, then continues with the
/// boundary and pseudo-random `u32` witnesses from `check_witnesses`.
///
/// Returns `(cases_checked, first_violation)`. If `op_id` does not participate
/// in the law, or a companion op cannot be resolved, the resulting violation is
/// reported with zero cases checked.
///
/// # Errors
///
/// Propagates the `Err` produced when a comparison kernel fails to execute.
#[inline]
pub(crate) fn check_exhaustive_u8(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    less_op: &str,
    equal_op: &str,
    greater_op: &str,
    witness_count: u64,
) -> Result<(u64, Option<LawViolation>), LawViolation> {
    const EXHAUSTIVE_PAIRS: u64 = 256 * 256;

    let ops = match ops_for(op_id, f, less_op, equal_op, greater_op) {
        Ok(resolved) => resolved,
        Err(setup_problem) => return Ok((0, Some(setup_problem))),
    };

    // Row-major walk over the full 256 x 256 grid; `zip` with `1u64..` gives
    // the running 1-based case number of each pair.
    let grid = (0u32..256).flat_map(|lhs| (0u32..256).map(move |rhs| (lhs, rhs)));
    for (case_no, (lhs, rhs)) in (1u64..).zip(grid) {
        if let Some(found) = check_case(op_id, ops, lhs, rhs)? {
            return Ok((case_no, Some(found)));
        }
    }

    check_witnesses(op_id, ops, EXHAUSTIVE_PAIRS, witness_count)
}

/// Checks the trichotomy law `lt(a,b) + eq(a,b) + gt(a,b) = 1` using only the
/// boundary-value grid plus `count` deterministic pseudo-random `u32` pairs.
///
/// Returns `(cases_checked, first_violation)`. A participation or
/// companion-resolution failure is reported as a violation with zero cases.
///
/// # Errors
///
/// Propagates the `Err` produced when a comparison kernel fails to execute.
#[inline]
pub(crate) fn check_witnessed_u32(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    less_op: &str,
    equal_op: &str,
    greater_op: &str,
    count: u64,
) -> Result<(u64, Option<LawViolation>), LawViolation> {
    match ops_for(op_id, f, less_op, equal_op, greater_op) {
        Ok(resolved) => check_witnesses(op_id, resolved, 0, count),
        Err(setup_problem) => Ok((0, Some(setup_problem))),
    }
}

/// Runs trichotomy checks on witness pairs, continuing a case count that
/// starts at `cases`.
///
/// There are two phases:
/// 1. Every ordered pair drawn from `BOUNDARY_VALUES`.
/// 2. `count` pairs from the RNG returned by `simple_rng(op_id, "trichotomy")`.
///
/// Stops at the first violation and returns `(total_cases, first_violation)`.
///
/// # Errors
///
/// Propagates the `Err` produced when a comparison kernel fails to execute.
fn check_witnesses(
    op_id: &str,
    ops: CompareOps,
    mut cases: u64,
    count: u64,
) -> Result<(u64, Option<LawViolation>), LawViolation> {
    let boundary_pairs = BOUNDARY_VALUES
        .iter()
        .flat_map(|&lhs| BOUNDARY_VALUES.iter().map(move |&rhs| (lhs, rhs)));
    for (lhs, rhs) in boundary_pairs {
        cases += 1;
        if let Some(found) = check_case(op_id, ops, lhs, rhs)? {
            return Ok((cases, Some(found)));
        }
    }

    // The RNG is only created once the boundary phase has passed cleanly.
    let mut rng = simple_rng(op_id, "trichotomy");
    let mut remaining = count;
    while remaining > 0 {
        remaining -= 1;
        // Draw `a` before `b` so the witness sequence stays reproducible.
        let lhs = rng.next_u32();
        let rhs = rng.next_u32();
        cases += 1;
        if let Some(found) = check_case(op_id, ops, lhs, rhs)? {
            return Ok((cases, Some(found)));
        }
    }

    Ok((cases, None))
}

/// Checks trichotomy for a single `(a, b)` pair.
///
/// Calls the less, equal and greater kernels on `(a, b)`, in that order, and
/// requires their outputs to sum to exactly 1.
///
/// Returns `Ok(None)` when the law holds and `Ok(Some(violation))` when it does
/// not. The violation reports the observed sum as `lhs` and the expected `1` as
/// `rhs`. An overflowing sum is reported as the integer type's maximum value.
///
/// # Errors
///
/// Returns `Err` (built by `engine_failure_violation`) if any kernel fails to
/// execute.
fn check_case(
    op_id: &str,
    ops: CompareOps,
    a: u32,
    b: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    let less = call_binary(ops.less, a, b).map_err(|e| engine_failure_violation(op_id, e))?;
    let equal = call_binary(ops.equal, a, b).map_err(|e| engine_failure_violation(op_id, e))?;
    let greater = call_binary(ops.greater, a, b).map_err(|e| engine_failure_violation(op_id, e))?;
    // Use saturating, not wrapping, addition. A wrapping sum can land exactly on
    // 1 for garbage kernel outputs (e.g. `2 + u32::MAX + 0`), which would hide a
    // broken backend. Saturation yields either the exact sum or MAX, and MAX is
    // never 1, so `sum == 1` holds only when the true sum is 1.
    let sum = less.saturating_add(equal).saturating_add(greater);
    if sum == 1 {
        Ok(None)
    } else {
        Ok(Some(violation(op_id, "trichotomy", a, b, 0, sum, 1)))
    }
}

/// The three comparison kernels whose outputs the trichotomy law relates.
///
/// Each entry is a CPU reference kernel operating on raw byte buffers. The
/// `u32` operands are presumably encoded and decoded by `call_binary`; confirm
/// against that helper.
#[derive(Copy, Clone)]
struct CompareOps {
    /// Kernel for the "less than" relation.
    less: fn(&[u8]) -> Vec<u8>,
    /// Kernel for the "equal" relation.
    equal: fn(&[u8]) -> Vec<u8>,
    /// Kernel for the "greater than" relation.
    greater: fn(&[u8]) -> Vec<u8>,
}

/// Resolves the less, equal and greater kernels that take part in trichotomy.
///
/// `f` is the kernel of `op_id` itself. It fills every role whose op id equals
/// `op_id`. Each other role is looked up by id in the op registry, and the
/// registered op must have a `u32,u32 -> u32` signature.
///
/// # Errors
///
/// Returns a `LawViolation` in two cases:
/// - `op_id` is none of `less_op`, `equal_op` or `greater_op`.
/// - A companion op cannot be resolved. Roles are resolved in the order less,
///   equal, greater, and the first failure is returned.
fn ops_for(
    op_id: &str,
    f: fn(&[u8]) -> Vec<u8>,
    less_op: &str,
    equal_op: &str,
    greater_op: &str,
) -> Result<CompareOps, LawViolation> {
    let participates = [less_op, equal_op, greater_op].contains(&op_id);
    if !participates {
        return Err(LawViolation {
            law: "Trichotomy".to_string(),
            op_id: op_id.to_string(),
            a: 0,
            b: 0,
            c: 0,
            lhs: 0,
            rhs: 0,
            message: format!(
                "Trichotomy declared on `{op_id}` but it is not one of `{less_op}`, \
                 `{equal_op}`, `{greater_op}`. Fix: declare Trichotomy only on a \
                 participating comparison op."
            ),
        });
    }

    // The op under test supplies its own kernel; every other role must come
    // from the registry.
    let resolve = |role_op: &str| {
        if role_op == op_id {
            Ok(f)
        } else {
            lookup_binary(role_op).ok_or_else(|| {
                missing_companion_violation(
                    op_id,
                    "Trichotomy",
                    role_op,
                    "comparison",
                    "u32,u32 -> u32",
                )
            })
        }
    };

    // Struct-literal fields are evaluated in the order written, so resolution
    // (and therefore which error is reported first) runs less -> equal -> greater.
    Ok(CompareOps {
        less: resolve(less_op)?,
        equal: resolve(equal_op)?,
        greater: resolve(greater_op)?,
    })
}

/// Returns the CPU reference kernel of the first registered op whose id is
/// `op_id` and whose signature is exactly `(U32, U32) -> U32`.
///
/// Returns `None` if no registered op satisfies both conditions.
fn lookup_binary(op_id: &str) -> Option<fn(&[u8]) -> Vec<u8>> {
    for spec in op_registry::all_specs() {
        if spec.id != op_id {
            continue;
        }
        let takes_two_u32 = spec.signature.inputs == [DataType::U32, DataType::U32];
        if takes_two_u32 && spec.signature.output == DataType::U32 {
            return Some(spec.cpu_fn);
        }
    }
    None
}