vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Wire format ↔ IR semantic equivalence enforcer.
//!
//! Every IR program must round-trip through the wire codec bit-exactly
//! (invariant I4: `from_wire(to_wire(p)).unwrap() == p`) AND the decoded
//! program must produce byte-for-byte identical backend dispatch output
//! as the original, proving the codec is lossless across both the
//! structural and the semantic surface.

use vyre::ir::{BinOp, BufferDecl, DataType, Expr, Node, Program};

/// Number of deterministic witness input pairs generated for each catalog
/// operation; the enforcer iterates `0..WIRE_FORMAT_EQ_CASES_PER_OP` per op.
const WIRE_FORMAT_EQ_CASES_PER_OP: usize = 128;

/// A single violation of wire-format ↔ IR semantic equivalence.
///
/// One value is recorded per failing op/witness pair. Failures detected
/// before any backend dispatch (program construction, encode, decode,
/// structural mismatch, re-encode) leave both output fields as `None`;
/// dispatch-level failures fill in whichever outputs were produced.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WireFormatEquivViolation {
    /// Human-readable operation name, such as `Add` or `BitXor`.
    pub op_name: String,
    /// Input bytes associated with this witness pair: the two `u32` operands
    /// serialized back to back in little-endian order.
    pub input_bytes: Vec<u8>,
    /// Hex dump of the original program dispatch output, or `None` if the
    /// dispatch failed or was never reached.
    pub original_output: Option<String>,
    /// Hex dump of the round-tripped program dispatch output, or `None` if the
    /// dispatch failed or was never reached.
    pub roundtripped_output: Option<String>,
    /// Actionable error message for the violation, ending in a `Fix:` hint.
    pub message: String,
}

/// Report from running the wire-format equivalence enforcer.
///
/// As produced by `enforce_wire_format_equivalence`, every tested witness
/// either increments `passed` or contributes exactly one entry to
/// `violations`, so `tested == passed + violations.len()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WireFormatEquivReport {
    /// All detected structural, codec, or backend-output violations.
    pub violations: Vec<WireFormatEquivViolation>,
    /// Number of op/input witnesses tested.
    pub tested: usize,
    /// Number of op/input witnesses that passed both structural and semantic checks.
    pub passed: usize,
}

/// One entry of the canonical operation catalog exercised by the enforcer.
struct WireCase {
    /// Display name copied into violation reports and messages.
    op_name: &'static str,
    /// Binary operation used to build the canonical probe program.
    binop: BinOp,
}

/// Returns the fixed catalog of binary operations covered by the enforcer.
///
/// The order is significant: the position of each entry is mixed into the
/// witness seed by the caller, so reordering changes the generated inputs.
fn wire_cases() -> [WireCase; 10] {
    [
        ("Add", BinOp::Add),
        ("Sub", BinOp::Sub),
        ("Mul", BinOp::Mul),
        ("Div", BinOp::Div),
        ("BitAnd", BinOp::BitAnd),
        ("BitOr", BinOp::BitOr),
        ("BitXor", BinOp::BitXor),
        ("Eq", BinOp::Eq),
        ("Ne", BinOp::Ne),
        ("Shl", BinOp::Shl),
    ]
    .map(|(op_name, binop)| WireCase { op_name, binop })
}

/// Enforce invariant I4 and backend semantic equivalence over canonical IR programs.
///
/// For every catalog operation, `WIRE_FORMAT_EQ_CASES_PER_OP` deterministic
/// non-zero operand pairs are generated. Each pair is baked into a canonical
/// program, which is then encoded, decoded through `Program::from_wire`,
/// compared structurally, and dispatched through `backend` in both its
/// original and decoded forms.
///
/// Returns a report in which each witness is counted once as tested and
/// then either as passed or as a recorded violation.
#[inline]
pub(crate) fn enforce_wire_format_equivalence(
    backend: &dyn vyre::VyreBackend,
) -> WireFormatEquivReport {
    let mut report = WireFormatEquivReport {
        violations: Vec::new(),
        tested: 0,
        passed: 0,
    };

    for (case_index, case) in wire_cases().into_iter().enumerate() {
        // The catalog position occupies the high half of the seed so that
        // different operations never share a witness stream.
        let seed_base = (case_index as u64) << 32;

        for witness in 0..WIRE_FORMAT_EQ_CASES_PER_OP {
            report.tested += 1;

            let (a, b) = u32_pair_from_seed(seed_base | witness as u64);
            let input_bytes = u32s_to_bytes([a, b]);

            // A program that cannot even be built is reported as a violation
            // and never reaches the round-trip checks.
            let outcome = match build_ir_program(case.binop.clone(), a, b) {
                Ok(program) => check_case(backend, &case, &program, &input_bytes, witness),
                Err(err) => Some(codec_violation(
                    &case,
                    &input_bytes,
                    format!("failed to build canonical {} program: {err}", case.op_name),
                )),
            };

            if let Some(violation) = outcome {
                report.violations.push(violation);
            } else {
                report.passed += 1;
            }
        }
    }

    report
}

/// Runs the structural and semantic round-trip checks for a single witness.
///
/// The checks run in order and stop at the first failure:
/// 1. `program` encodes with `to_wire`;
/// 2. the bytes decode with `Program::from_wire`;
/// 3. the decoded program compares equal to `program`;
/// 4. the decoded program encodes again;
/// 5. dispatching both programs on `backend` succeeds with equal output.
///
/// Returns `None` when every check passes, otherwise the violation to record.
fn check_case(
    backend: &dyn vyre::VyreBackend,
    case: &WireCase,
    program: &Program,
    input_bytes: &[u8],
    witness: usize,
) -> Option<WireFormatEquivViolation> {
    // All pre-dispatch failures share one shape; only the message differs.
    let codec_failure = |message: String| Some(codec_violation(case, input_bytes, message));

    let encoded = match program.to_wire() {
        Ok(bytes) => bytes,
        Err(err) => {
            return codec_failure(format!(
                "to_wire failed for witness {witness}: {err}. Fix: serialize a valid IR Program."
            ));
        }
    };

    let decoded = match Program::from_wire(&encoded) {
        Ok(decoded) => decoded,
        Err(err) => {
            return codec_failure(format!(
                "from_wire failed for witness {witness}: {err}. Fix: keep Program::to_wire and Program::from_wire compatible."
            ));
        }
    };

    if decoded != *program {
        return codec_failure(format!(
            "wire round-trip changed IR structure for witness {witness}. Fix: preserve every Program field in the VIR0 codec."
        ));
    }

    // Only the success of the second encode is checked; its bytes are not compared.
    if let Err(err) = decoded.to_wire() {
        return codec_failure(format!(
            "re-encoding round-tripped IR failed for witness {witness}: {err}. Fix: decode only encodable Program values."
        ));
    }

    // Both dispatches receive the same single input buffer and a 4-byte output.
    let inputs = [input_bytes.to_vec()];
    let original_output = dispatch_exact(backend, program, &inputs, 4);
    let roundtripped_output = dispatch_exact(backend, &decoded, &inputs, 4);

    if let (Ok(before), Ok(after)) = (&original_output, &roundtripped_output) {
        if before == after {
            return None;
        }
    }

    // Reached on differing outputs or when either dispatch failed.
    let dump = |result: &Result<Vec<u8>, vyre::BackendError>| {
        result.as_ref().ok().map(|bytes| hex_dump(bytes))
    };
    Some(WireFormatEquivViolation {
        op_name: case.op_name.to_string(),
        input_bytes: input_bytes.to_vec(),
        original_output: dump(&original_output),
        roundtripped_output: dump(&roundtripped_output),
        message: format!(
            "wire-format semantic divergence for {} witness {}: original_path={}, roundtripped_path={}. Fix: make Program::from_wire(to_wire(p)) preserve backend-visible semantics.",
            case.op_name,
            witness,
            result_label(&original_output),
            result_label(&roundtripped_output),
        ),
    })
}

/// Dispatches `program` on `backend` and returns its first output buffer.
///
/// The program is first re-declared so that its output buffer is sized for
/// `output_size` bytes. An error is returned when the backend dispatch fails,
/// when it yields no output buffers, or when the first buffer is not exactly
/// `output_size` bytes long.
fn dispatch_exact(
    backend: &dyn vyre::VyreBackend,
    program: &Program,
    inputs: &[Vec<u8>],
    output_size: usize,
) -> Result<Vec<u8>, vyre::BackendError> {
    let sized_program = program_with_output_size(program, output_size);
    let outputs = backend.dispatch(&sized_program, inputs, &vyre::DispatchConfig::default())?;

    // Only the first buffer is inspected; any further buffers are dropped.
    let Some(first) = outputs.into_iter().next() else {
        return Err(vyre::BackendError::new(
            "backend returned zero output buffers. Fix: return the wire-format probe output as outputs[0].",
        ));
    };

    if first.len() == output_size {
        Ok(first)
    } else {
        Err(vyre::BackendError::new(format!(
            "backend returned {} bytes, expected {output_size}. Fix: size the first output buffer from the wire-format probe output declaration.",
            first.len()
        )))
    }
}

/// Clones `program` with its first read-write buffer marked as the output.
///
/// The chosen buffer's element count becomes `output_size` divided by four,
/// rounded up, saturating at `u32::MAX`. Programs without a read-write
/// buffer are rebuilt unchanged.
fn program_with_output_size(program: &Program, output_size: usize) -> Program {
    let element_count: u32 = output_size.div_ceil(4).try_into().unwrap_or(u32::MAX);

    let mut buffers = program.buffers().to_vec();
    let first_read_write = buffers
        .iter_mut()
        .find(|decl| decl.access == vyre::ir::BufferAccess::ReadWrite);
    if let Some(decl) = first_read_write {
        decl.is_output = true;
        decl.count = element_count;
    }

    Program::new(buffers, program.workgroup_size(), program.entry().to_vec())
}

/// Builds a violation for a failure that occurred before any backend dispatch.
///
/// Both output fields are `None` because no dispatch output exists yet.
fn codec_violation(
    case: &WireCase,
    input_bytes: &[u8],
    message: String,
) -> WireFormatEquivViolation {
    let op_name = String::from(case.op_name);
    let input_bytes = Vec::from(input_bytes);
    WireFormatEquivViolation {
        op_name,
        input_bytes,
        original_output: None,
        roundtripped_output: None,
        message,
    }
}

/// Renders a dispatch result for inclusion in a violation message.
///
/// Successful output becomes `ok(<hex>)`; a failure becomes the error's
/// display text.
fn result_label(result: &Result<Vec<u8>, vyre::BackendError>) -> String {
    result.as_ref().map_or_else(
        |err| err.to_string(),
        |bytes| format!("ok({})", hex_dump(bytes)),
    )
}

/// Builds the canonical single-store probe program for `binop`.
///
/// The program declares one read-write `U32` buffer named `results`, uses a
/// `[1, 1, 1]` workgroup, stores `a <binop> b` at index 0 and returns.
/// Fails with the message from `binary_expr` when `binop` is not supported.
fn build_ir_program(binop: BinOp, a: u32, b: u32) -> Result<Program, String> {
    let buffers = vec![BufferDecl::read_write("results", 0, DataType::U32)];
    let workgroup = [1, 1, 1];

    let index = Expr::u32(0);
    let value = binary_expr(binop, a, b)?;
    let body = vec![Node::store("results", index, value), Node::Return];

    Ok(Program::new(buffers, workgroup, body))
}

/// Builds the expression `a <binop> b` with both operands as `u32` literals.
///
/// Only the ten operations in the canonical catalog are handled; any other
/// `BinOp` yields an `Err` carrying an actionable message.
fn binary_expr(binop: BinOp, a: u32, b: u32) -> Result<Expr, String> {
    let (lhs, rhs) = (Expr::u32(a), Expr::u32(b));
    match binop {
        BinOp::Add => Ok(Expr::add(lhs, rhs)),
        BinOp::Sub => Ok(Expr::sub(lhs, rhs)),
        BinOp::Mul => Ok(Expr::mul(lhs, rhs)),
        BinOp::Div => Ok(Expr::div(lhs, rhs)),
        BinOp::BitAnd => Ok(Expr::bitand(lhs, rhs)),
        BinOp::BitOr => Ok(Expr::bitor(lhs, rhs)),
        BinOp::BitXor => Ok(Expr::bitxor(lhs, rhs)),
        BinOp::Eq => Ok(Expr::eq(lhs, rhs)),
        BinOp::Ne => Ok(Expr::ne(lhs, rhs)),
        BinOp::Shl => Ok(Expr::shl(lhs, rhs)),
        // The wildcard is required because `BinOp` has variants outside the catalog.
        _ => Err(format!(
            "unsupported canonical binary operation {binop:?}. Fix: keep the wire_format_eq catalog limited to explicitly handled BinOp variants."
        )),
    }
}

/// Derives a deterministic pair of operands from `seed`.
///
/// Each operand is a `split_mix_u32` output masked to its low 15 bits and
/// then incremented, so both values lie in `1..=32_768` and are never zero.
fn u32_pair_from_seed(seed: u64) -> (u32, u32) {
    const SECOND_SEED_OFFSET: u64 = 0x9e37_79b9_7f4a_7c15;
    const LOW_BITS: u32 = 0x7fff;

    let to_operand = |raw: u32| (raw & LOW_BITS) + 1;

    let first = split_mix_u32(seed);
    let second = split_mix_u32(seed.wrapping_add(SECOND_SEED_OFFSET));
    (to_operand(first), to_operand(second))
}

/// Hashes `seed` into a `u32` with a SplitMix64-style mixing step.
///
/// The seed is advanced by the additive constant, passed through two
/// xor-shift/multiply rounds and a final xor-shift, and the 64-bit result is
/// folded by wrapping addition of its high and low halves.
fn split_mix_u32(seed: u64) -> u32 {
    const INCREMENT: u64 = 0x9e37_79b9_7f4a_7c15;
    const MULTIPLIER_A: u64 = 0xbf58_476d_1ce4_e5b9;
    const MULTIPLIER_B: u64 = 0x94d0_49bb_1331_11eb;

    let state = seed.wrapping_add(INCREMENT);
    let round_one = (state ^ (state >> 30)).wrapping_mul(MULTIPLIER_A);
    let round_two = (round_one ^ (round_one >> 27)).wrapping_mul(MULTIPLIER_B);
    let mixed = round_two ^ (round_two >> 31);

    let high = (mixed >> 32) as u32;
    let low = mixed as u32; // truncation keeps the low 32 bits
    high.wrapping_add(low)
}

/// Serializes two `u32` values back to back in little-endian byte order.
fn u32s_to_bytes(values: [u32; 2]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(values.len() * 4);
    for value in values {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}

/// Formats `bytes` as a lowercase hexadecimal string, two digits per byte,
/// with no separators or prefix.
fn hex_dump(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut text = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble.
        text.push(char::from(DIGITS[usize::from(byte >> 4)]));
        text.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    text
}

/// Registry entry for `wire_format_equivalence` enforcement.
///
/// A stateless gate that runs the wire-format equivalence enforcer against
/// the backend supplied in the enforcement context.
pub struct WireFormatEquivalenceEnforcer;

impl crate::enforce::EnforceGate for WireFormatEquivalenceEnforcer {
    /// Identifier under which findings from this gate are reported.
    fn id(&self) -> &'static str {
        "wire_format_equivalence"
    }

    /// Display name of the gate; identical to its identifier.
    fn name(&self) -> &'static str {
        "wire_format_equivalence"
    }

    /// Runs the enforcer and converts each violation message into findings.
    ///
    /// Without a backend in `ctx`, a single aggregate finding explaining the
    /// missing backend is returned instead.
    fn run(&self, ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        match ctx.backend {
            Some(backend) => {
                let report = enforce_wire_format_equivalence(backend);
                let mut messages = Vec::with_capacity(report.violations.len());
                for violation in report.violations {
                    messages.push(violation.message);
                }
                crate::enforce::finding_result(self.id(), messages)
            }
            None => {
                let reason = "wire_format_equivalence: backend is required. Fix: provide a VyreBackend in EnforceCtx.".to_string();
                vec![crate::enforce::aggregate_finding(self.id(), vec![reason])]
            }
        }
    }
}

/// Auto-registered `wire_format_equivalence` enforcer.
///
/// A constant instance of the unit struct `WireFormatEquivalenceEnforcer`.
// NOTE(review): the mechanism that collects this constant is not visible in
// this file — confirm how the registry discovers it.
pub const REGISTERED: WireFormatEquivalenceEnforcer = WireFormatEquivalenceEnforcer;

#[cfg(test)]
mod tests {
    use super::*;

    /// The catalog size is pinned so that accidental additions or removals
    /// (which would shift witness seeds) are noticed.
    #[test]
    fn canonical_catalog_has_ten_entries() {
        let catalog = wire_cases();
        assert_eq!(catalog.len(), 10);
    }

    /// The same seed must always yield the same operands, and both operands
    /// must stay inside the documented non-zero range.
    #[test]
    fn prng_is_deterministic_and_bounded() {
        let (a, b) = u32_pair_from_seed(42);
        assert_eq!((a, b), u32_pair_from_seed(42));

        let allowed = 1..=32_768;
        for operand in [a, b] {
            assert!(allowed.contains(&operand));
        }
    }

    /// A canonical program must survive encode → decode unchanged.
    #[test]
    fn canonical_program_round_trips_structurally() {
        let original =
            build_ir_program(BinOp::Add, 7, 9).expect("canonical test program should build");
        let wire = original
            .to_wire()
            .expect("canonical test program should encode");
        let decoded = Program::from_wire(&wire).expect("canonical test program should decode");
        assert_eq!(decoded, original);
    }
}