vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
use vyre::ir::{BinOp, BufferDecl, DataType, Expr, Node, Program};
use vyre_conform::{reference::interp, spec::value::Value, specs::primitive};

/// Number of witnessed input cases generated and checked for each primitive spec.
const CASES: usize = 10_000;

/// Runs every primitive spec through the reference interpreter and checks its
/// output against the spec's own `cpu_fn`, byte for byte.
///
/// For each spec, `CASES` cases of `arity` little-endian `u32` words are
/// generated by `witnessed_words` and executed by the one-call program from
/// `primitive_program`. Each 4-byte output word is then compared with `cpu_fn`
/// applied to that case's input bytes.
///
/// # Panics
///
/// Panics (failing the test) in any of these cases:
/// - a spec is not unary or binary;
/// - the interpreter returns an error;
/// - the first output is not `Value::Bytes` of exactly `CASES * 4` bytes;
/// - any case mismatches (the first one found is reported).
#[test]
fn reference_agrees_with_every_primitive_cpu_fn_for_witnessed_u32_inputs() {
    for spec in primitive::specs() {
        let arity = spec.signature.inputs.len();
        assert!(
            arity == 1 || arity == 2,
            "primitive agreement harness supports unary/binary u32 specs; {} has arity {arity}",
            spec.id
        );

        let input_words = witnessed_words(spec.id, arity, CASES);
        let input_bytes = words_to_bytes(&input_words);
        let output_init = vec![0u8; CASES * 4];
        let program = primitive_program(spec.id, arity);
        let outputs = interp::run(
            &program,
            &[Value::Bytes(input_bytes.clone()), Value::Bytes(output_init)],
        )
        .unwrap_or_else(|error| panic!("reference interpreter failed for {}: {error}", spec.id));
        let Value::Bytes(output_bytes) = &outputs[0] else {
            panic!("reference output for {} was not bytes", spec.id);
        };

        // Check the size up front. Without this, a short buffer would either
        // panic with an opaque slice-index error or (with the chunked zip
        // below) silently skip the missing cases.
        assert_eq!(
            output_bytes.len(),
            CASES * 4,
            "reference output for {} has {} bytes, expected {}",
            spec.id,
            output_bytes.len(),
            CASES * 4
        );

        // `arity` is 1 or 2 (asserted above), so `chunks_exact` never sees a
        // zero chunk size. All three streams yield exactly `CASES` chunks.
        let cases = input_words
            .chunks_exact(arity)
            .zip(input_bytes.chunks_exact(arity * 4))
            .zip(output_bytes.chunks_exact(4))
            .enumerate();
        for (case, ((case_words, case_bytes), actual)) in cases {
            let expected = (spec.cpu_fn)(case_bytes);
            assert_eq!(
                actual,
                expected.as_slice(),
                "reference interpreter disagrees with cpu_fn for {}\ncase: {case}\ninput_words: {:?}\ninput_bytes: {:02x?}\nexpected: {:02x?}\nactual: {:02x?}",
                spec.id,
                case_words,
                case_bytes,
                expected,
                actual
            );
        }
    }
}

/// Builds a single-call program that applies primitive `op_id` once per invocation.
///
/// Invocation `gid_x` reads `arity` consecutive `u32` words from `input`. They
/// start at index `gid_x` when `arity == 1`, or at `gid_x * arity` otherwise.
/// The words are passed as the arguments of `Expr::Call { op_id, .. }`, and the
/// result is stored to `out[gid_x]`. The store only happens while
/// `gid_x < buf_len("out")`. The workgroup size is fixed at `[64, 1, 1]`.
fn primitive_program(op_id: &'static str, arity: usize) -> Program {
    let gid = Expr::gid_x();

    // Index of this invocation's first argument word in `input`.
    let first_word = match arity {
        1 => gid.clone(),
        _ => Expr::BinOp {
            op: BinOp::Mul,
            left: Box::new(gid.clone()),
            right: Box::new(Expr::u32(arity as u32)),
        },
    };

    // Index of the argument word at `offset` past `first_word`.
    // Offset 0 reuses the base expression unchanged instead of adding `+ 0`.
    let word_at = |offset: usize| {
        if offset == 0 {
            first_word.clone()
        } else {
            Expr::BinOp {
                op: BinOp::Add,
                left: Box::new(first_word.clone()),
                right: Box::new(Expr::u32(offset as u32)),
            }
        }
    };
    let args = (0..arity)
        .map(|offset| Expr::load("input", word_at(offset)))
        .collect();

    let call = Expr::Call {
        op_id: op_id.to_string(),
        args,
    };
    let in_bounds = Expr::lt(gid.clone(), Expr::buf_len("out"));
    let body = vec![Node::if_then(
        in_bounds,
        vec![Node::store("out", gid, call)],
    )];
    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32),
        BufferDecl::read_write("out", 1, DataType::U32),
    ];

    Program::new(buffers, [64, 1, 1], body)
}

/// Generates `cases * arity` deterministic `u32` input words for primitive `op_id`.
///
/// Words are laid out case-major: the `arity` words of case 0 come first, then
/// case 1, and so on. Each word pairs with a boundary value chosen by
/// `(case + arg) % 9`.
/// - When `case % 5 == 0`, the boundary value itself is emitted.
/// - Otherwise the word is a linear-congruential state, rotated left by
///   `(case + arg) & 31`, then XORed with the boundary.
///
/// The state starts from `fnv1a32(op_id) ^ 0xA53C_9E17` and advances once per
/// word, including boundary-only words. Output is therefore reproducible per
/// `op_id`.
fn witnessed_words(op_id: &str, arity: usize, cases: usize) -> Vec<u32> {
    const BOUNDARIES: [u32; 9] = [0, 1, 2, 31, 32, 33, 0x7FFF_FFFF, 0x8000_0000, u32::MAX];

    let mut state = fnv1a32(op_id.as_bytes()) ^ 0xA53C_9E17;
    let mut out = Vec::with_capacity(arity * cases);
    out.extend(
        (0..cases)
            .flat_map(|case| (0..arity).map(move |arg| (case, arg)))
            .map(|(case, arg)| {
                let slot = case + arg;
                let boundary = BOUNDARIES[slot % BOUNDARIES.len()];
                state = state
                    .wrapping_mul(1_664_525)
                    .wrapping_add(1_013_904_223);
                if case % 5 != 0 {
                    state.rotate_left((slot & 31) as u32) ^ boundary
                } else {
                    boundary
                }
            }),
    );
    out
}

/// Serializes `words` as a flat little-endian byte buffer (4 bytes per word, in order).
fn words_to_bytes(words: &[u32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(words.len() * 4);
    for word in words {
        bytes.extend_from_slice(&word.to_le_bytes());
    }
    bytes
}

/// Computes the 32-bit FNV-1a hash of `bytes`.
///
/// Each byte is XORed into the hash, and the hash is then multiplied by the
/// FNV prime, wrapping mod 2^32.
fn fnv1a32(bytes: &[u8]) -> u32 {
    const OFFSET_BASIS: u32 = 0x811C_9DC5;
    const PRIME: u32 = 0x0100_0193;

    bytes
        .iter()
        .fold(OFFSET_BASIS, |acc, &b| (acc ^ u32::from(b)).wrapping_mul(PRIME))
}