vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Rebuild cases for relation laws that depend on companion operations.

use crate::proof::algebra::checker::support::{
    call_binary, call_unary, engine_failure_violation, violation,
};
use crate::spec::law::{AlgebraicLaw, LawViolation, MonotonicDirection};
use crate::{spec::op_registry, DataType};

type OpFn = fn(&[u8]) -> Vec<u8>;

pub(super) fn rebuild_distributive(
    op_id: &str,
    f: OpFn,
    over_op: &str,
    a: u32,
    b: u32,
    c: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    let over = lookup_binary_op(over_op)
        .ok_or_else(|| missing_rebuild_companion(op_id, "DistributiveOver", over_op, "binary"))?;
    let over_bc = call_binary(over, b, c).map_err(|e| engine_failure_violation(op_id, e))?;
    let lhs = call_binary(f, a, over_bc).map_err(|e| engine_failure_violation(op_id, e))?;
    let ab = call_binary(f, a, b).map_err(|e| engine_failure_violation(op_id, e))?;
    let ac = call_binary(f, a, c).map_err(|e| engine_failure_violation(op_id, e))?;
    let rhs = call_binary(over, ab, ac).map_err(|e| engine_failure_violation(op_id, e))?;
    if lhs != rhs {
        return Ok(Some(violation(op_id, "distributive", a, b, c, lhs, rhs)));
    }

    let over_ab = call_binary(over, a, b).map_err(|e| engine_failure_violation(op_id, e))?;
    let lhs = call_binary(f, over_ab, c).map_err(|e| engine_failure_violation(op_id, e))?;
    let ac = call_binary(f, a, c).map_err(|e| engine_failure_violation(op_id, e))?;
    let bc = call_binary(f, b, c).map_err(|e| engine_failure_violation(op_id, e))?;
    let rhs = call_binary(over, ac, bc).map_err(|e| engine_failure_violation(op_id, e))?;
    if lhs != rhs {
        Ok(Some(violation(op_id, "distributive", a, b, c, lhs, rhs)))
    } else {
        Ok(None)
    }
}

/// Re-checks the binary `Complement { complement_op, universe }` law for `op_id`.
///
/// The law holds when `f(a, b)` plus `complement_op(a, b)`, added with wrapping
/// `u32` arithmetic, equals `universe`. On failure, the observed sum is reported
/// against `universe`.
///
/// # Errors
/// Returns `Err(..)` in two cases:
/// * `complement_op` is not a registered `(u32, u32) -> u32` op;
/// * an engine call fails.
pub(super) fn rebuild_binary_complement(
    op_id: &str,
    f: OpFn,
    complement_op: &str,
    universe: u32,
    a: u32,
    b: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    let Some(complement) = lookup_binary_op(complement_op) else {
        return Err(missing_rebuild_companion(
            op_id,
            "Complement",
            complement_op,
            "binary",
        ));
    };
    let direct = call_binary(f, a, b).map_err(|e| engine_failure_violation(op_id, e))?;
    let complemented =
        call_binary(complement, a, b).map_err(|e| engine_failure_violation(op_id, e))?;
    let total = direct.wrapping_add(complemented);
    Ok((total != universe).then(|| violation(op_id, "complement", a, b, 0, total, universe)))
}

/// Re-checks `LatticeAbsorption { dual_op }` for `op_id` on `(a, b)`.
///
/// Both absorption identities are evaluated in order:
///
/// 1. `f(a, dual(a, b)) == a`
/// 2. `dual(a, f(a, b)) == a`
///
/// The first failing identity is reported, comparing its computed value with `a`.
///
/// # Errors
/// Returns `Err(..)` in two cases:
/// * `dual_op` is not a registered `(u32, u32) -> u32` op;
/// * an engine call fails.
pub(super) fn rebuild_lattice_absorption(
    op_id: &str,
    f: OpFn,
    dual_op: &str,
    a: u32,
    b: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    let dual = lookup_binary_op(dual_op)
        .ok_or_else(|| missing_rebuild_companion(op_id, "LatticeAbsorption", dual_op, "dual"))?;
    let eval = |g: OpFn, x: u32, y: u32| {
        call_binary(g, x, y).map_err(|e| engine_failure_violation(op_id, e))
    };

    // Identity 1: f(a, dual(a, b)) == a.
    let absorbed = eval(f, a, eval(dual, a, b)?)?;
    if absorbed != a {
        return Ok(Some(violation(
            op_id,
            "lattice-absorption",
            a,
            b,
            0,
            absorbed,
            a,
        )));
    }

    // Identity 2: dual(a, f(a, b)) == a.
    let absorbed = eval(dual, a, eval(f, a, b)?)?;
    Ok((absorbed != a)
        .then(|| violation(op_id, "lattice-absorption", a, b, 0, absorbed, a)))
}

pub(super) fn rebuild_inverse(
    op_id: &str,
    f: OpFn,
    inverse_op: &str,
    a: u32,
    b: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    let inverse = lookup_binary_op(inverse_op)
        .ok_or_else(|| missing_rebuild_companion(op_id, "InverseOf", inverse_op, "binary"))?;
    let inverse_ab = call_binary(inverse, a, b).map_err(|e| engine_failure_violation(op_id, e))?;
    let actual = call_binary(f, inverse_ab, b).map_err(|e| engine_failure_violation(op_id, e))?;
    if actual != a {
        Ok(Some(violation(op_id, "inverse-of", a, b, 0, actual, a)))
    } else {
        Ok(None)
    }
}

/// Re-checks `Trichotomy { less_op, equal_op, greater_op }` for `op_id` on `(a, b)`.
///
/// The three predicates are evaluated in the order less, equal, greater. Their
/// results are summed with wrapping `u32` addition, and the sum must be exactly
/// `1`. Any companion whose id equals `op_id` resolves to `f` itself.
///
/// # Errors
/// Returns `Err(..)` in two cases:
/// * a companion cannot be resolved;
/// * an engine call fails.
pub(super) fn rebuild_trichotomy(
    op_id: &str,
    f: OpFn,
    less_op: &str,
    equal_op: &str,
    greater_op: &str,
    a: u32,
    b: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    let predicates = [
        select_binary_op(op_id, f, less_op, "Trichotomy", "less")?,
        select_binary_op(op_id, f, equal_op, "Trichotomy", "equal")?,
        select_binary_op(op_id, f, greater_op, "Trichotomy", "greater")?,
    ];

    let mut total: u32 = 0;
    for predicate in predicates {
        let verdict =
            call_binary(predicate, a, b).map_err(|e| engine_failure_violation(op_id, e))?;
        total = total.wrapping_add(verdict);
    }

    if total == 1 {
        Ok(None)
    } else {
        Ok(Some(violation(op_id, "trichotomy", a, b, 0, total, 1)))
    }
}

/// Re-checks a user-supplied `AlgebraicLaw::Custom` law of arity 1–3 against `f`.
///
/// The first `arity` values of `(a, b, c)` are passed to the law's `check`.
/// Possible reports:
/// * a non-custom law is reported via `unimplemented_rebuild`;
/// * an arity outside `1..=3` is reported with `lhs = arity` and `rhs = 3`;
/// * a failed check is reported with `lhs = 1` and `rhs = 0`.
pub(super) fn rebuild_binary_custom(
    op_id: &str,
    f: OpFn,
    law: &AlgebraicLaw,
    a: u32,
    b: u32,
    c: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    let AlgebraicLaw::Custom {
        name, arity, check, ..
    } = law
    else {
        return Ok(Some(unimplemented_rebuild(op_id, law, a, b, c)));
    };
    let report = |lhs, rhs| violation(op_id, &format!("custom({name})"), a, b, c, lhs, rhs);

    // Take a prefix of the three candidate inputs instead of allocating.
    let inputs = [a, b, c];
    let args: &[u32] = match *arity {
        1 => &inputs[..1],
        2 => &inputs[..2],
        3 => &inputs[..],
        _ => return Ok(Some(report(*arity as u32, 3))),
    };

    Ok((!check(f, args)).then(|| report(1, 0)))
}

/// Re-checks the monotone law for unary `f` by comparing `f(0)` with `f(a)`.
///
/// Every `u32` satisfies `0 <= a`, so `f(0) > f(a)` is reported as a
/// violation, with inputs `(0, a)` and values `(f(0), f(a))`.
///
/// # Errors
/// Returns `Err(..)` when an engine call fails.
pub(super) fn rebuild_monotone(
    op_id: &str,
    f: OpFn,
    a: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    let at_zero = call_unary(f, 0).map_err(|e| engine_failure_violation(op_id, e))?;
    let at_a = call_unary(f, a).map_err(|e| engine_failure_violation(op_id, e))?;
    Ok((at_zero > at_a).then(|| violation(op_id, "monotone", 0, a, 0, at_zero, at_a)))
}

/// Re-checks `Monotonic { direction }` for unary `f` by comparing `f(0)` with `f(a)`.
///
/// * `NonDecreasing` requires `f(0) <= f(a)`.
/// * `NonIncreasing` requires `f(0) >= f(a)`.
///
/// Any other direction is treated as failing. Violations report inputs
/// `(0, a)` and values `(f(0), f(a))`.
///
/// # Errors
/// Returns `Err(..)` when an engine call fails.
pub(super) fn rebuild_monotonic(
    op_id: &str,
    f: OpFn,
    direction: &MonotonicDirection,
    a: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    let at_zero = call_unary(f, 0).map_err(|e| engine_failure_violation(op_id, e))?;
    let at_a = call_unary(f, a).map_err(|e| engine_failure_violation(op_id, e))?;
    let holds = match direction {
        MonotonicDirection::NonDecreasing => at_zero <= at_a,
        MonotonicDirection::NonIncreasing => at_zero >= at_a,
        // Catch-all for any other direction: always reported as a violation.
        // NOTE(review): unreachable if the enum only has the two variants above.
        _ => false,
    };
    if holds {
        return Ok(None);
    }
    Ok(Some(violation(op_id, "monotonic", 0, a, 0, at_zero, at_a)))
}

/// Re-checks `DeMorgan { inner_op, dual_op }` for unary `f` at the pair `(a, 0)`.
///
/// The law holds when `f(inner_op(a, 0)) == dual_op(f(a), f(0))`. The second
/// operand is fixed at `0`, so only `a` varies between cases.
///
/// # Errors
/// Returns `Err(..)` in two cases:
/// * either companion is not a registered `(u32, u32) -> u32` op;
/// * an engine call fails.
pub(super) fn rebuild_demorgan(
    op_id: &str,
    f: OpFn,
    inner_op: &str,
    dual_op: &str,
    a: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    let inner = lookup_binary_op(inner_op)
        .ok_or_else(|| missing_rebuild_companion(op_id, "DeMorgan", inner_op, "inner"))?;
    let dual = lookup_binary_op(dual_op)
        .ok_or_else(|| missing_rebuild_companion(op_id, "DeMorgan", dual_op, "dual"))?;
    let unary = |x: u32| call_unary(f, x).map_err(|e| engine_failure_violation(op_id, e));
    let binary = |g: OpFn, x: u32, y: u32| {
        call_binary(g, x, y).map_err(|e| engine_failure_violation(op_id, e))
    };

    // f(inner(a, 0))
    let lhs = unary(binary(inner, a, 0)?)?;
    // dual(f(a), f(0))
    let rhs = binary(dual, unary(a)?, unary(0)?)?;

    if lhs == rhs {
        Ok(None)
    } else {
        Ok(Some(violation(op_id, "de-morgan", a, 0, 0, lhs, rhs)))
    }
}

/// Re-checks the unary `Complement { complement_op, universe }` law for `f` at `a`.
///
/// The law holds when `f(a)` plus `f(complement_op(a))`, added with wrapping
/// `u32` arithmetic, equals `universe`. On failure, the observed sum is reported
/// against `universe`.
///
/// # Errors
/// Returns `Err(..)` in two cases:
/// * `complement_op` is not a registered `u32 -> u32` op;
/// * an engine call fails.
pub(super) fn rebuild_unary_complement(
    op_id: &str,
    f: OpFn,
    complement_op: &str,
    universe: u32,
    a: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    let Some(complement) = lookup_unary_op(complement_op) else {
        return Err(missing_rebuild_companion(
            op_id,
            "Complement",
            complement_op,
            "unary",
        ));
    };
    let apply = |g: OpFn, x: u32| call_unary(g, x).map_err(|e| engine_failure_violation(op_id, e));

    let direct = apply(f, a)?;
    let mirrored = apply(f, apply(complement, a)?)?;
    let total = direct.wrapping_add(mirrored);

    if total == universe {
        Ok(None)
    } else {
        Ok(Some(violation(op_id, "complement", a, 0, 0, total, universe)))
    }
}

/// Re-checks a user-supplied unary `AlgebraicLaw::Custom` law against `f` at `a`.
///
/// Possible reports:
/// * a non-custom law is reported via `unimplemented_rebuild`;
/// * an arity other than `1` is reported with `lhs = arity` and `rhs = 1`;
/// * a failed check is reported with `lhs = 1` and `rhs = 0`.
pub(super) fn rebuild_unary_custom(
    op_id: &str,
    f: OpFn,
    law: &AlgebraicLaw,
    a: u32,
) -> Result<Option<LawViolation>, LawViolation> {
    let AlgebraicLaw::Custom {
        name, arity, check, ..
    } = law
    else {
        return Ok(Some(unimplemented_rebuild(op_id, law, a, 0, 0)));
    };
    let report = |lhs, rhs| violation(op_id, &format!("custom({name})"), a, 0, 0, lhs, rhs);

    match *arity {
        1 if check(f, &[a]) => Ok(None),
        1 => Ok(Some(report(1, 0))),
        other => Ok(Some(report(other as u32, 1))),
    }
}

/// Builds the placeholder violation used when there is no rebuild routine for `law`.
///
/// The law label is `"unimplemented rebuild: <law name>"`. The witness carries
/// the given inputs with `lhs = 0` and `rhs = 1`.
pub(super) fn unimplemented_rebuild(
    op_id: &str,
    law: &AlgebraicLaw,
    a: u32,
    b: u32,
    c: u32,
) -> LawViolation {
    let label = format!("unimplemented rebuild: {}", law.name());
    violation(op_id, &label, a, b, c, 0, 1)
}

/// Returns the CPU reference function of the registered op `op_id`.
///
/// Returns `None` unless the op exists and its signature is `(u32, u32) -> u32`.
fn lookup_binary_op(op_id: &str) -> Option<OpFn> {
    for spec in op_registry::all_specs() {
        if spec.id == op_id
            && spec.signature.inputs == [DataType::U32, DataType::U32]
            && spec.signature.output == DataType::U32
        {
            return Some(spec.cpu_fn);
        }
    }
    None
}

/// Returns the CPU reference function of the registered op `op_id`.
///
/// Returns `None` unless the op exists and its signature is `u32 -> u32`.
fn lookup_unary_op(op_id: &str) -> Option<OpFn> {
    op_registry::all_specs().into_iter().find_map(|spec| {
        let matches = spec.id == op_id
            && spec.signature.inputs == [DataType::U32]
            && spec.signature.output == DataType::U32;
        matches.then_some(spec.cpu_fn)
    })
}

/// Resolves a binary companion op needed to rebuild `law` for `op_id`.
///
/// If the companion `requested` is `op_id` itself, the function under test
/// (`current`) is returned unchanged. Otherwise the companion is looked up as a
/// registered `(u32, u32) -> u32` op.
///
/// # Errors
/// Returns a "missing companion" violation, naming `role`, when the lookup fails.
fn select_binary_op(
    op_id: &str,
    current: OpFn,
    requested: &str,
    law: &str,
    role: &str,
) -> Result<OpFn, LawViolation> {
    if requested == op_id {
        return Ok(current);
    }
    match lookup_binary_op(requested) {
        Some(found) => Ok(found),
        None => Err(missing_rebuild_companion(op_id, law, requested, role)),
    }
}

/// Builds the violation reported when a companion op required by `law` cannot be found.
///
/// All numeric witness fields are zero. The message names the missing
/// companion and its role, and suggests registering the op or dropping the law.
fn missing_rebuild_companion(op_id: &str, law: &str, companion: &str, role: &str) -> LawViolation {
    let message = format!(
        "missing {role} companion op `{companion}` while rebuilding {law} violation on `{op_id}`. Fix: register the companion op or remove the declared law."
    );
    LawViolation {
        op_id: op_id.to_string(),
        law: law.to_string(),
        message,
        a: 0,
        b: 0,
        c: 0,
        lhs: 0,
        rhs: 0,
    }
}