//! vyre-conform 0.1.0
//!
//! Conformance suite for vyre backends — proves byte-identical output to the CPU reference.
use vyre::ir::{BufferDecl, DataType as IrDataType, Expr, Node, Program};
use vyre_conform::{
    reference::interp,
    registry,
    spec::value::Value,
    types::{DataType as SpecDataType, OpSpec},
};

/// How many deterministic witness cases `run_witnessed_u32_cases` generates for each OpSpec.
const WITNESSED_CASES: usize = 10_000;

/// Checks that the reference interpreter (`interp::run`) reproduces every registered
/// OpSpec's `cpu_fn` oracle byte-for-byte.
///
/// Specs whose signature allows it are checked exhaustively over u8-valued inputs; every
/// spec is additionally checked against deterministic u32 witness inputs seeded from its id.
#[test]
fn reference_run_bit_matches_cpu_fn_for_registered_specs() {
    let specs = registry::all_specs();
    assert!(
        !specs.is_empty(),
        "Fix: registry::all_specs must expose the registered OpSpec set"
    );

    specs.iter().for_each(|spec| {
        run_exhaustive_u8_when_tractable(spec);
        run_witnessed_u32_cases(spec);
    });
}

/// Enumerates every u8-valued input combination for unary and binary specs and checks them
/// in one batch through `assert_fixed_width_cases`.
///
/// Does nothing when `supports_exhaustive_u8` rejects the spec. Binary specs produce
/// 256 × 256 cases, ordered with the first argument as the outer loop.
fn run_exhaustive_u8_when_tractable(spec: &OpSpec) {
    if supports_exhaustive_u8(spec) {
        let table: Vec<Vec<u32>> = match spec.signature.inputs.len() {
            1 => (u8::MIN..=u8::MAX).map(|lhs| vec![u32::from(lhs)]).collect(),
            // supports_exhaustive_u8 guarantees the only other arity here is 2.
            _ => (u8::MIN..=u8::MAX)
                .flat_map(|lhs| {
                    (u8::MIN..=u8::MAX).map(move |rhs| vec![u32::from(lhs), u32::from(rhs)])
                })
                .collect(),
        };
        assert_fixed_width_cases(spec, &table, "exhaustive_u8");
    }
}

/// Checks `WITNESSED_CASES` deterministic witness inputs for `spec`, seeded from its id.
///
/// Fixed-width signatures are batched into a single interpreter dispatch. Any other
/// signature is checked one case at a time with per-argument byte buffers.
fn run_witnessed_u32_cases(spec: &OpSpec) {
    let mut rng_state = seed_from_op_id(spec.id);

    if !is_fixed_width_signature(spec) {
        for case in 0..WITNESSED_CASES {
            let encoded = witnessed_variable_input(spec, &mut rng_state, case);
            assert_single_case(spec, &encoded, "witnessed_u32", case);
        }
        return;
    }

    let width = spec.signature.inputs.len();
    let mut batch = Vec::with_capacity(WITNESSED_CASES);
    for case in 0..WITNESSED_CASES {
        let words: Vec<u32> = (0..width)
            .map(|arg| witnessed_word(&mut rng_state, case, arg))
            .collect();
        batch.push(words);
    }
    assert_fixed_width_cases(spec, &batch, "witnessed_u32");
}

/// Returns whether `spec` can be checked exhaustively over u8-valued inputs.
///
/// Requires all of the following:
/// - arity 1 or 2;
/// - every input is `U32` or `I32`;
/// - the output has a fixed byte width.
///
/// The output check runs last, and only when the first two hold.
fn supports_exhaustive_u8(spec: &OpSpec) -> bool {
    let params = &spec.signature.inputs;
    let small_arity = matches!(params.len(), 1 | 2);
    let word_params = params
        .iter()
        .all(|ty| matches!(ty, SpecDataType::U32 | SpecDataType::I32));
    small_arity && word_params && fixed_output_size(&spec.signature.output).is_some()
}

/// Returns whether every input and the output of `spec` map to a fixed-width IR type.
fn is_fixed_width_signature(spec: &OpSpec) -> bool {
    let sig = &spec.signature;
    let has_variable_input = sig.inputs.iter().any(|ty| fixed_input_type(ty).is_none());
    !has_variable_input && fixed_output_size(&sig.output).is_some()
}

/// Runs every case in `cases` through the reference interpreter in one dispatch and asserts
/// that the packed `out` buffer equals the concatenation of per-case `cpu_fn` results.
///
/// Each case holds one `u32` word per input argument. Argument `k` of every case is packed,
/// in case order, into the `argk` buffer via `append_word_as_type`. Invocation `i` of the
/// program from `fixed_width_program` stores its result in slot `i` of `out`. Each slot is
/// `fixed_output_size(output)` bytes wide.
///
/// # Panics
///
/// Panics in any of these cases:
/// - the interpreter errors;
/// - it returns no byte output;
/// - its output differs from `cpu_fn`.
///
/// On disagreement the message names the first differing case, its input words, and both
/// byte slots, instead of dumping two whole multi-case buffers.
fn assert_fixed_width_cases(spec: &OpSpec, cases: &[Vec<u32>], label: &str) {
    let program = fixed_width_program(spec, cases.len());
    let output_size = fixed_output_size(&spec.signature.output).expect("checked by caller");

    // Buffer `k` holds argument `k` of every case, in case order.
    let mut inputs = spec
        .signature
        .inputs
        .iter()
        .enumerate()
        .map(|(arg_index, ty)| {
            let mut bytes = Vec::new();
            for case in cases {
                append_word_as_type(&mut bytes, case[arg_index], ty);
            }
            Value::Bytes(bytes)
        })
        .collect::<Vec<_>>();
    inputs.push(Value::Bytes(vec![0; cases.len() * output_size]));

    let outputs = interp::run(&program, &inputs)
        .unwrap_or_else(|error| panic!("Fix: interp::run failed for {} {label}: {error}", spec.id));
    // `.first()` turns an empty result into the same actionable message as a non-byte one.
    let Some(Value::Bytes(actual)) = outputs.first() else {
        panic!("Fix: interp::run returned non-byte output for {}", spec.id);
    };

    // cpu_fn sees each case's arguments concatenated in declaration order.
    let per_case = cases
        .iter()
        .map(|case| (spec.cpu_fn)(&encode_case_words(spec, case)))
        .collect::<Vec<_>>();
    let expected = per_case.concat();
    if *actual == expected {
        return;
    }
    report_fixed_width_mismatch(spec, label, cases, &per_case, actual, output_size);
}

/// Encodes one case's argument words into the concatenated byte layout passed to `cpu_fn`.
fn encode_case_words(spec: &OpSpec, case: &[u32]) -> Vec<u8> {
    let mut bytes = Vec::new();
    for (word, ty) in case.iter().zip(&spec.signature.inputs) {
        append_word_as_type(&mut bytes, *word, ty);
    }
    bytes
}

/// Panics with a diagnostic naming the first case whose `out` slot differs from `cpu_fn`.
///
/// Falls back to a total-length report when every per-case slot matches but the buffers
/// still differ, for example when the interpreter wrote trailing bytes.
fn report_fixed_width_mismatch(
    spec: &OpSpec,
    label: &str,
    cases: &[Vec<u32>],
    per_case: &[Vec<u8>],
    actual: &[u8],
    output_size: usize,
) -> ! {
    let first_bad = per_case.iter().enumerate().find_map(|(index, expected)| {
        let start = index * output_size;
        let slot = actual.get(start..start + output_size);
        (slot != Some(expected.as_slice())).then_some((index, slot, expected))
    });
    match first_bad {
        Some((index, slot, expected)) => panic!(
            "Fix: reference interpreter disagrees with cpu_fn for {} on {label} case {index} \
             (input words {:?}): interpreter wrote {slot:?}, cpu_fn returned {expected:?}",
            spec.id, cases[index]
        ),
        None => panic!(
            "Fix: reference interpreter disagrees with cpu_fn for {} on {label}: \
             interpreter wrote {} bytes, cpu_fn returned {} bytes in total",
            spec.id,
            actual.len(),
            per_case.iter().map(Vec::len).sum::<usize>()
        ),
    }
}

/// Builds the batched program used by `assert_fixed_width_cases`.
///
/// Each invocation loads element `gid.x` from every `argN` buffer and applies `spec` via
/// `Expr::Call`. It then stores the result into element `gid.x` of the `out` buffer.
/// Invocations with `gid.x >= cases` skip the store because of the bounds guard.
/// `[64, 1, 1]` is passed to `Program::new` (presumably the workgroup shape — confirm).
///
/// # Panics
///
/// Panics via `expect` if an input or the output has no fixed-width IR mapping; callers
/// check this beforehand.
fn fixed_width_program(spec: &OpSpec, cases: usize) -> Program {
    let lane = Expr::gid_x();
    let mut args = Vec::new();
    let mut buffers = Vec::new();
    for (binding, ty) in spec.signature.inputs.iter().enumerate() {
        let name = format!("arg{binding}");
        args.push(Expr::load(&name, lane.clone()));
        buffers.push(BufferDecl::read(
            &name,
            binding as u32,
            fixed_input_type(ty).expect("checked by caller"),
        ));
    }

    // The output buffer takes the binding slot right after the last input.
    let out_binding = buffers.len() as u32;
    let out_type = fixed_input_type(&spec.signature.output).expect("checked by caller");
    buffers.push(BufferDecl::read_write("out", out_binding, out_type));

    let in_bounds = Expr::lt(lane.clone(), Expr::u32(cases as u32));
    let store = Node::store(
        "out",
        lane,
        Expr::Call {
            op_id: spec.id.to_string(),
            args,
        },
    );
    Program::new(buffers, [64, 1, 1], vec![Node::if_then(in_bounds, vec![store])])
}

/// Runs one variable-width witness case through the reference interpreter and asserts the
/// result equals `cpu_fn` applied to the concatenated argument bytes.
///
/// `input` holds one encoded byte buffer per argument. The `out` buffer is zero-filled and
/// sized to the `cpu_fn` result. `case` is used only in diagnostics.
///
/// # Panics
///
/// Panics in any of these cases:
/// - the interpreter errors;
/// - it returns no byte output;
/// - its output differs from `cpu_fn`.
///
/// The disagreement message includes the generated argument bytes, so a failing witness can
/// be reproduced without re-running the generator.
fn assert_single_case(spec: &OpSpec, input: &[Vec<u8>], label: &str, case: usize) {
    let expected = (spec.cpu_fn)(&input.concat());
    let program = variable_width_program(spec, expected.len());
    let mut values = input.iter().cloned().map(Value::Bytes).collect::<Vec<_>>();
    values.push(Value::Bytes(vec![0; expected.len()]));
    let outputs = interp::run(&program, &values).unwrap_or_else(|error| {
        panic!(
            "Fix: interp::run failed for {} {label} case {case}: {error}",
            spec.id
        )
    });
    // `.first()` turns an empty result into the same actionable message as a non-byte one.
    let Some(Value::Bytes(actual)) = outputs.first() else {
        panic!("Fix: interp::run returned non-byte output for {}", spec.id);
    };
    assert_eq!(
        actual, &expected,
        "Fix: reference interpreter disagrees with cpu_fn for {} on {label} case {case} \
         (input {input:?})",
        spec.id
    );
}

/// Builds a single-invocation program for one variable-width case.
///
/// The program loads element 0 of every `argN` buffer, applies `spec` via `Expr::Call`,
/// and stores the result at element 0 of `out`. Inputs without a fixed-width mapping are
/// declared as `Bytes`. The output type comes from `output_type_or_bytes`, which uses
/// `Bytes` when `output_len` is 0.
fn variable_width_program(spec: &OpSpec, output_len: usize) -> Program {
    let params = &spec.signature.inputs;
    let names: Vec<String> = (0..params.len()).map(|slot| format!("arg{slot}")).collect();

    let operands = names
        .iter()
        .map(|name| Expr::load(name, Expr::u32(0)))
        .collect::<Vec<_>>();
    let mut decls = names
        .iter()
        .zip(params.iter())
        .enumerate()
        .map(|(slot, (name, ty))| BufferDecl::read(name, slot as u32, input_type_or_bytes(ty)))
        .collect::<Vec<_>>();

    let out_slot = decls.len() as u32;
    decls.push(BufferDecl::read_write(
        "out",
        out_slot,
        output_type_or_bytes(&spec.signature.output, output_len),
    ));

    let call = Expr::Call {
        op_id: spec.id.to_string(),
        args: operands,
    };
    Program::new(decls, [1, 1, 1], vec![Node::store("out", Expr::u32(0), call)])
}

/// Generates one encoded byte buffer per input argument for witness case `case`.
///
/// - `Bytes`, `Array` and `Tensor` arguments get a pseudo-random length in `0..=64`,
///   followed by that many witness-derived bytes.
/// - Every other type is one witness word encoded by `append_word_as_type`.
///
/// `seed` is advanced once per generated word, in argument order.
fn witnessed_variable_input(spec: &OpSpec, seed: &mut u32, case: usize) -> Vec<Vec<u8>> {
    let mut buffers = Vec::with_capacity(spec.signature.inputs.len());
    for (arg, ty) in spec.signature.inputs.iter().enumerate() {
        let is_blob = matches!(
            ty,
            SpecDataType::Bytes | SpecDataType::Array { .. } | SpecDataType::Tensor
        );
        let encoded = if is_blob {
            let length = (witnessed_word(seed, case, arg) as usize) % 65;
            let mut blob = Vec::with_capacity(length);
            for offset in 0..length {
                blob.push(witnessed_word(seed, case, arg + offset) as u8);
            }
            blob
        } else {
            let word = witnessed_word(seed, case, arg);
            let mut scalar = Vec::new();
            append_word_as_type(&mut scalar, word, ty);
            scalar
        };
        buffers.push(encoded);
    }
    buffers
}

/// Appends the little-endian encoding of `word` as a value of spec type `ty` to `bytes`.
///
/// - 32-bit scalars (`U32`, `I32`, `F32`) write the word itself.
/// - `U64` / `Vec2U32` write the word, then the word rotated left by 17.
/// - `Vec4U32` writes four lanes, lane `i` being the word rotated left by `7 * i`.
/// - `F16` / `BF16` write the low 16 bits.
/// - `Bytes`, `Array` and `Tensor` append nothing.
///
/// # Panics
///
/// Panics for any other type.
fn append_word_as_type(bytes: &mut Vec<u8>, word: u32, ty: &SpecDataType) {
    match ty {
        SpecDataType::Bytes | SpecDataType::Array { .. } | SpecDataType::Tensor => {}
        SpecDataType::F16 | SpecDataType::BF16 => bytes.extend((word as u16).to_le_bytes()),
        SpecDataType::U32 | SpecDataType::I32 | SpecDataType::F32 => {
            bytes.extend(word.to_le_bytes())
        }
        SpecDataType::U64 | SpecDataType::Vec2U32 => {
            for half in [word, word.rotate_left(17)] {
                bytes.extend(half.to_le_bytes());
            }
        }
        SpecDataType::Vec4U32 => {
            bytes.extend((0..4u32).flat_map(|lane| word.rotate_left(lane * 7).to_le_bytes()));
        }
        unsupported => {
            panic!("Fix: add reference parity input encoding for unsupported type {unsupported:?}")
        }
    }
}

/// Maps a spec type to its fixed-width IR buffer type.
///
/// Returns `None` for the types this suite treats as variable-width: `F16`, `BF16`,
/// `Bytes`, `Array` and `Tensor`. `F32` is declared with the `U32` IR type.
///
/// # Panics
///
/// Panics for any spec type without a mapping here.
fn fixed_input_type(ty: &SpecDataType) -> Option<IrDataType> {
    let mapped = match ty {
        SpecDataType::F16
        | SpecDataType::BF16
        | SpecDataType::Bytes
        | SpecDataType::Array { .. }
        | SpecDataType::Tensor => return None,
        SpecDataType::U32 | SpecDataType::F32 => IrDataType::U32,
        SpecDataType::I32 => IrDataType::I32,
        SpecDataType::U64 => IrDataType::U64,
        SpecDataType::Vec2U32 => IrDataType::Vec2U32,
        SpecDataType::Vec4U32 => IrDataType::Vec4U32,
        unsupported => panic!(
            "Fix: add reference parity fixed-width mapping for unsupported type {unsupported:?}"
        ),
    };
    Some(mapped)
}

/// Byte width of one value of `ty` when it has a fixed-width IR mapping, as reported by
/// `IrDataType::min_bytes`; `None` otherwise.
fn fixed_output_size(ty: &SpecDataType) -> Option<usize> {
    let ir_type = fixed_input_type(ty)?;
    Some(ir_type.min_bytes())
}

/// IR buffer type for an input: the fixed-width mapping when one exists, else `Bytes`.
fn input_type_or_bytes(ty: &SpecDataType) -> IrDataType {
    match fixed_input_type(ty) {
        Some(ir_type) => ir_type,
        None => IrDataType::Bytes,
    }
}

/// IR buffer type for the output of a single variable-width case.
///
/// An empty output (`output_len == 0`) is always `Bytes`, and `ty` is not inspected in that
/// case. Otherwise the fixed-width mapping of `ty` is used when it exists, else `Bytes`.
fn output_type_or_bytes(ty: &SpecDataType, output_len: usize) -> IrDataType {
    let fixed = if output_len > 0 {
        fixed_input_type(ty)
    } else {
        None
    };
    fixed.unwrap_or(IrDataType::Bytes)
}

/// Produces the next deterministic witness word for argument `arg` of case `case`.
///
/// Every call advances `seed` with a linear congruential step
/// (`seed * 1_664_525 + 1_013_904_223`, wrapping). A boundary value is selected by
/// `(case + arg)` from a fixed table of edge cases. Every fourth case (`case % 4 == 0`)
/// returns that boundary unchanged. Other cases return the advanced seed rotated left by
/// `(case + arg) & 31`, XORed with the boundary.
fn witnessed_word(seed: &mut u32, case: usize, arg: usize) -> u32 {
    const BOUNDARIES: [u32; 11] = [
        0,
        1,
        2,
        7,
        31,
        32,
        33,
        0x7FFF_FFFF,
        0x8000_0000,
        0xFFFF_FFFE,
        u32::MAX,
    ];
    let mix = case + arg;
    let advanced = seed.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
    *seed = advanced;
    let edge = BOUNDARIES[mix % BOUNDARIES.len()];
    match case % 4 {
        0 => edge,
        _ => advanced.rotate_left((mix & 31) as u32) ^ edge,
    }
}

/// Derives a stable per-op witness seed by hashing `op_id` with 32-bit FNV-1a.
fn seed_from_op_id(op_id: &str) -> u32 {
    const FNV_OFFSET_BASIS: u32 = 0x811C_9DC5;
    const FNV_PRIME: u32 = 0x0100_0193;
    op_id.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u32::from(byte)).wrapping_mul(FNV_PRIME)
    })
}