vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
use std::sync::atomic::{AtomicU64, Ordering};

use vyre_conform::backend::DispatchConfig;
use vyre_conform::spec::law::AlgebraicLaw;
use vyre_conform::{
    certify, CertificateStrength, OpOutcome, OpSignature, OpSpec, Strictness, VyreBackend,
};

/// Number of times any law-breaking fixture below has run; every fixture bumps it once
/// per call via `counted`. The test resets it before each case and checks it afterwards.
static CPU_CALLS: AtomicU64 = AtomicU64::new(0);

/// Test backend whose `dispatch` just runs a CPU function on the raw input bytes, so its
/// output matches the op's CPU reference whenever both are the same function.
struct MirrorBackend {
    /// Function executed on every `dispatch`; the test sets it to the spec's `cpu_fn`.
    cpu_fn: fn(&[u8]) -> Vec<u8>,
}

impl VyreBackend for MirrorBackend {
    /// Identifier reported for this mirror backend.
    fn name(&self) -> &str {
        "law-dispatch-mirror"
    }

    /// Fixed version tag; this backend exists only for this test.
    fn version(&self) -> &str {
        "test"
    }

    /// Ignores the shader source, output-size hint and dispatch config, and returns
    /// whatever the wrapped CPU function produces for `input`. Never returns `Err`.
    fn dispatch(
        &self,
        _wgsl: &str,
        input: &[u8],
        _output_size: usize,
        _config: DispatchConfig,
    ) -> Result<Vec<u8>, String> {
        let reference = self.cpu_fn;
        let produced = reference(input);
        Ok(produced)
    }
}

/// Decodes two little-endian `u32` operands from the first 8 bytes of `input`.
///
/// Inputs shorter than 8 bytes decode as `(0, 0)`; bytes past the eighth are ignored.
fn pair(input: &[u8]) -> (u32, u32) {
    match input {
        &[a0, a1, a2, a3, b0, b1, b2, b3, ..] => (
            u32::from_le_bytes([a0, a1, a2, a3]),
            u32::from_le_bytes([b0, b1, b2, b3]),
        ),
        _ => (0, 0),
    }
}

/// Records one fixture invocation in `CPU_CALLS` and returns `value` as 4 little-endian bytes.
fn counted(value: u32) -> Vec<u8> {
    let encoded = value.to_le_bytes();
    CPU_CALLS.fetch_add(1, Ordering::Relaxed);
    Vec::from(encoded)
}

/// Fixture for `Commutative`: wrapping subtraction, where `a - b` and `b - a` differ in
/// general.
fn non_commutative(input: &[u8]) -> Vec<u8> {
    let (minuend, subtrahend) = pair(input);
    let difference = minuend.wrapping_sub(subtrahend);
    counted(difference)
}

/// Fixture for `Associative`: wrapping subtraction. `(a - b) - c` and `a - (b - c)`
/// differ by `2c` (mod 2^32), so they disagree for most nonzero `c`.
fn non_associative(input: &[u8]) -> Vec<u8> {
    let operands = pair(input);
    counted(operands.0.wrapping_sub(operands.1))
}

/// Fixture for two-sided `Identity { element: 0 }`: wrapping addition, except that a
/// right operand of 0 yields `a ^ 1` instead of `a`, so `f(a, 0) != a`.
fn bad_identity(input: &[u8]) -> Vec<u8> {
    let (lhs, rhs) = pair(input);
    let result = match rhs {
        0 => lhs ^ 1,
        _ => lhs.wrapping_add(rhs),
    };
    counted(result)
}

/// Fixture for `LeftIdentity { element: 0 }`: wrapping addition, except that a left
/// operand of 0 yields `b ^ 1`, so `f(0, b) != b`.
fn bad_left_identity(input: &[u8]) -> Vec<u8> {
    let (lhs, rhs) = pair(input);
    let result = if lhs != 0 {
        lhs.wrapping_add(rhs)
    } else {
        rhs ^ 1
    };
    counted(result)
}

/// Fixture for `RightIdentity { element: 0 }`: wrapping addition, except that a right
/// operand of 0 yields `a ^ 1`, so `f(a, 0) != a`.
fn bad_right_identity(input: &[u8]) -> Vec<u8> {
    let (lhs, rhs) = pair(input);
    counted(match rhs {
        0 => lhs ^ 1,
        nonzero => lhs.wrapping_add(nonzero),
    })
}

/// Fixture for `SelfInverse { result: 0 }`: bitwise OR. `f(a, a) == a` is never 0 for a
/// nonzero `a`.
fn bad_self_inverse(input: &[u8]) -> Vec<u8> {
    let (lhs, rhs) = pair(input);
    let combined = rhs | lhs;
    counted(combined)
}

/// Fixture for `Idempotent`: wrapping addition. `f(a, a) == 2a` (mod 2^32) differs from
/// `a` for every nonzero `a`.
fn bad_idempotent(input: &[u8]) -> Vec<u8> {
    let (lhs, rhs) = pair(input);
    let total = rhs.wrapping_add(lhs);
    counted(total)
}

/// Fixture for `Absorbing { element: 0 }`: bitwise OR. `f(a, 0) == a` rather than 0.
fn bad_absorbing(input: &[u8]) -> Vec<u8> {
    let operands = pair(input);
    counted(operands.1 | operands.0)
}

/// Fixture for `LeftAbsorbing { element: 0 }`: wrapping addition. `f(0, b) == b` rather
/// than 0.
fn bad_left_absorbing(input: &[u8]) -> Vec<u8> {
    let operands = pair(input);
    let sum = operands.0.wrapping_add(operands.1);
    counted(sum)
}

/// Fixture for `RightAbsorbing`: wrapping addition. `f(a, e) == a + e`, which equals
/// `e` only when `a == 0`, so no element absorbs from the right.
fn bad_right_absorbing(input: &[u8]) -> Vec<u8> {
    let (left, right) = pair(input);
    counted(right.wrapping_add(left))
}

/// Fixture for `LeftAbsorbing { element: 1 }`: returns the left operand unchanged,
/// except that a left operand of 1 maps to 0, so `f(1, b) == 0` rather than 1.
fn bad_left_absorbing_nonzero_element(input: &[u8]) -> Vec<u8> {
    let left = pair(input).0;
    counted(match left {
        1 => 0,
        other => other,
    })
}

/// Fixture for `RightAbsorbing { element: 1 }`: returns the right operand unchanged,
/// except that a right operand of 1 maps to 0, so `f(a, 1) == 0` rather than 1.
fn bad_right_absorbing_nonzero_element(input: &[u8]) -> Vec<u8> {
    let right = pair(input).1;
    counted(match right {
        1 => 0,
        other => other,
    })
}

/// Fixture for `Bounded { lo: 0, hi: 1 }`: wrapping addition, whose results often exceed
/// 1 (e.g. `f(1, 1) == 2`).
fn bad_bounded(input: &[u8]) -> Vec<u8> {
    let (augend, addend) = pair(input);
    let sum = augend.wrapping_add(addend);
    counted(sum)
}

/// Fixture for `ZeroProduct { holds: true }`: returns 0 for every input, including
/// nonzero operand pairs such as `(1, 1)`.
fn bad_zero_product_true(input: &[u8]) -> Vec<u8> {
    // The operands are irrelevant; they are decoded only to mirror the other fixtures.
    let _operands = pair(input);
    counted(0)
}

/// Fixture for `ZeroProduct { holds: false }`: bitwise OR, which is 0 only when both
/// operands are 0, so no nonzero pair produces zero.
///
/// NOTE(review): this presumably relies on the checker reading `holds: false` as a claim
/// that such a zero-producing nonzero pair exists; confirm.
fn bad_zero_product_false(input: &[u8]) -> Vec<u8> {
    let operands = pair(input);
    let merged = operands.1 | operands.0;
    counted(merged)
}

fn spec_for(id: &'static str, law: AlgebraicLaw, cpu_fn: fn(&[u8]) -> Vec<u8>) -> OpSpec {
    let mut spec = vyre_conform::specs::primitive::xor::spec();
    spec.id = id;
    spec.signature = OpSignature {
        inputs: vec![vyre_conform::DataType::U32, vyre_conform::DataType::U32],
        output: vyre_conform::DataType::U32,
    };
    spec.strictness = Strictness::Strict;
    spec.cpu_fn = cpu_fn;
    spec.laws = vec![law];
    spec
}

/// Passes `certify` one law-breaking binary `u32` op per `AlgebraicLaw` case and asserts
/// that each op is rejected with a recorded law violation.
///
/// Every case pairs a law with a fixture whose CPU reference breaks it. If `certify`
/// skipped law checking, the op would pass and the first assertion would fire. The plain
/// absorbing cases use element 0; the `_nonzero` cases cover element 1.
#[test]
fn certify_invokes_real_binary_law_checks() {
    let cases: &[(&str, AlgebraicLaw, fn(&[u8]) -> Vec<u8>)] = &[
        (
            "test.dispatch.commutative",
            AlgebraicLaw::Commutative,
            non_commutative,
        ),
        (
            "test.dispatch.associative",
            AlgebraicLaw::Associative,
            non_associative,
        ),
        (
            "test.dispatch.identity",
            AlgebraicLaw::Identity { element: 0 },
            bad_identity,
        ),
        (
            "test.dispatch.left_identity",
            AlgebraicLaw::LeftIdentity { element: 0 },
            bad_left_identity,
        ),
        (
            "test.dispatch.right_identity",
            AlgebraicLaw::RightIdentity { element: 0 },
            bad_right_identity,
        ),
        (
            "test.dispatch.self_inverse",
            AlgebraicLaw::SelfInverse { result: 0 },
            bad_self_inverse,
        ),
        (
            "test.dispatch.idempotent",
            AlgebraicLaw::Idempotent,
            bad_idempotent,
        ),
        (
            "test.dispatch.absorbing",
            AlgebraicLaw::Absorbing { element: 0 },
            bad_absorbing,
        ),
        (
            "test.dispatch.left_absorbing",
            AlgebraicLaw::LeftAbsorbing { element: 0 },
            bad_left_absorbing,
        ),
        (
            "test.dispatch.right_absorbing",
            // Element 0 mirrors `left_absorbing`; element 1 is covered by
            // `right_absorbing_nonzero` below. `a + 0 == a` breaks absorption for a != 0.
            AlgebraicLaw::RightAbsorbing { element: 0 },
            bad_right_absorbing,
        ),
        (
            "test.dispatch.left_absorbing_nonzero",
            AlgebraicLaw::LeftAbsorbing { element: 1 },
            bad_left_absorbing_nonzero_element,
        ),
        (
            "test.dispatch.right_absorbing_nonzero",
            AlgebraicLaw::RightAbsorbing { element: 1 },
            bad_right_absorbing_nonzero_element,
        ),
        (
            "test.dispatch.bounded",
            AlgebraicLaw::Bounded { lo: 0, hi: 1 },
            bad_bounded,
        ),
        (
            "test.dispatch.zero_product_true",
            AlgebraicLaw::ZeroProduct { holds: true },
            bad_zero_product_true,
        ),
        (
            "test.dispatch.zero_product_false",
            AlgebraicLaw::ZeroProduct { holds: false },
            bad_zero_product_false,
        ),
    ];

    for &(id, ref law, cpu_fn) in cases {
        // Reset per case so the final check only sees calls made while certifying this op.
        CPU_CALLS.store(0, Ordering::Relaxed);
        let spec = spec_for(id, law.clone(), cpu_fn);
        // The backend mirrors the CPU reference, so any failure should come from the law
        // checks rather than from an output mismatch.
        let backend = MirrorBackend { cpu_fn };
        let cert = certify(&backend, &[spec], CertificateStrength::FastCheck)
            .expect("certificate should render rejected law evidence");

        // Bind the op list before indexing so a missing report fails with a clear message.
        let ops = cert.ops();
        assert_eq!(ops.len(), 1, "{id} should produce exactly one op report");
        let op = &ops[0];

        assert_eq!(
            op.outcome(),
            OpOutcome::Failed,
            "{id} passed with law dispatch bypassed"
        );
        assert!(
            !op.laws_failed().is_empty(),
            "{id} did not record a {} law violation",
            law.name()
        );
        // NOTE(review): `MirrorBackend::dispatch` runs the same counting fixture, so this
        // only proves the fixture ran somewhere. It cannot distinguish backend dispatch
        // from a direct CPU-reference call.
        assert!(
            CPU_CALLS.load(Ordering::Relaxed) > 0,
            "{id} never executed its CPU reference"
        );
    }
}