vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Probe generation for reference-trust enforcement.

use crate::spec::{LabeledProbe, ReferenceTrustProbes};
use crate::{DataType, OpSignature};

/// Fixed seed for the pseudo-random portion of the reference-trust probes.
///
/// The value is the ASCII text `"VYRETRUS"` interpreted as a big-endian
/// `u64`, i.e. `0x5659_5245_5452_5553`.
pub const REFERENCE_TRUST_SEED: u64 = u64::from_be_bytes(*b"VYRETRUS");

/// Build `count` deterministic byte strings for independent reference comparisons.
///
/// The first seven entries are fixed edge cases: the empty input, the single
/// bytes `0`, `1` and `u8::MAX`, 32-byte runs of `0xAA` and `0x55`, and the
/// ASCII digits `"123456789"`. Every further entry is a pseudo-random buffer
/// of 1 to 256 bytes drawn from a generator seeded with
/// [`REFERENCE_TRUST_SEED`], so repeated calls return identical data.
///
/// When `count` is smaller than seven, only the first `count` fixed cases are
/// returned.
#[inline]
pub fn byte_reference_inputs(count: usize) -> Vec<Vec<u8>> {
    let fixed: [Vec<u8>; 7] = [
        Vec::new(),
        vec![0],
        vec![1],
        vec![u8::MAX],
        vec![0xAA; 32],
        vec![0x55; 32],
        b"123456789".to_vec(),
    ];

    let mut probes: Vec<Vec<u8>> = Vec::with_capacity(count.max(fixed.len()));
    probes.extend(fixed);

    let mut rng = ProbeRng::new(REFERENCE_TRUST_SEED);
    for _ in probes.len()..count {
        // The length is drawn first, then one draw per byte; a zero-length
        // draw is bumped to one so random probes are never empty.
        let len = match rng.next_u32() as usize % 257 {
            0 => 1,
            drawn => drawn,
        };
        let random: Vec<u8> = (0..len).map(|_| (rng.next_u32() & 0xFF) as u8).collect();
        probes.push(random);
    }

    // Only has an effect when fewer probes than the fixed set were requested.
    probes.truncate(count);
    probes
}

/// Produce labelled boundary inputs tailored to one operation signature.
///
/// * If every input type satisfies `is_numeric` (which is also true for a
///   signature with no inputs), the probes come from
///   `numeric_boundary_inputs`.
/// * If the signature is exactly one `DataType::Bytes` input, one probe is
///   produced per entry of `byte_lengths()`, labelled `bytes-len-{len}`.
///   Even offsets hold `0xAA`; odd offsets hold the offset truncated to `u8`,
///   multiplied by 37 with wrapping, then XORed with `0x55`.
/// * Any other signature yields a single `minimum-width-zero` probe made of
///   `signature.min_input_bytes()` zero bytes.
#[inline]
pub fn boundary_inputs(signature: &OpSignature) -> Vec<(String, Vec<u8>)> {
    if signature.inputs.iter().all(is_numeric) {
        return numeric_boundary_inputs(signature);
    }

    if signature.inputs != [DataType::Bytes] {
        let zeros = vec![0; signature.min_input_bytes()];
        return vec![("minimum-width-zero".to_string(), zeros)];
    }

    let lengths = byte_lengths();
    let mut probes: Vec<(String, Vec<u8>)> = Vec::with_capacity(lengths.len());
    for &len in lengths {
        let mut payload: Vec<u8> = Vec::with_capacity(len);
        for idx in 0..len {
            let byte = if idx % 2 == 0 {
                0xAA
            } else {
                (idx as u8).wrapping_mul(37) ^ 0x55
            };
            payload.push(byte);
        }
        probes.push((format!("bytes-len-{len}"), payload));
    }
    probes
}

/// Generate deterministic property probes for output-shape assertions.
///
/// The result always starts with everything `boundary_inputs` returns for
/// `signature`. When the signature is exactly one `DataType::Bytes` input,
/// the 16 entries of `byte_reference_inputs(16)` are appended, labelled
/// `property-random-{idx}` with `idx` counting from zero.
///
/// The signature check is performed once up front, so the reference inputs
/// are only generated for signatures that actually use them.
#[inline]
pub fn property_inputs(signature: &OpSignature) -> Vec<(String, Vec<u8>)> {
    let mut probes = boundary_inputs(signature);
    if signature.inputs == [DataType::Bytes] {
        probes.extend(
            byte_reference_inputs(16)
                .into_iter()
                .enumerate()
                .map(|(idx, input)| (format!("property-random-{idx}"), input)),
        );
    }
    probes
}

/// Assemble the full probe set handed to reference-trust enforcement.
///
/// The three arguments passed to `ReferenceTrustProbes::new` are, in order:
/// the output of `boundary_inputs`, the output of `property_inputs` (each
/// pair wrapped via `LabeledProbe::new`), and the 32 unlabelled buffers from
/// `byte_reference_inputs(32)`.
#[must_use]
#[inline]
pub fn reference_trust_probes(signature: &OpSignature) -> ReferenceTrustProbes {
    // Shared conversion from a `(label, bytes)` pair into a labelled probe.
    let into_labeled = |(name, bytes): (String, Vec<u8>)| LabeledProbe::new(name, bytes);

    let boundary = boundary_inputs(signature)
        .into_iter()
        .map(into_labeled)
        .collect();
    let property = property_inputs(signature)
        .into_iter()
        .map(into_labeled)
        .collect();
    let reference = byte_reference_inputs(32);

    ReferenceTrustProbes::new(boundary, property, reference)
}

/// Build boundary probes for a signature whose inputs are all numeric.
///
/// * With no inputs, a single `empty` probe with no bytes is returned.
/// * Otherwise one `numeric-all-{idx}` probe is emitted per entry of
///   `numeric_values()`, encoding that same value for every input.
/// * With two or more inputs, one `numeric-mixed-{idx}` probe is additionally
///   emitted per sliding window of `arity` consecutive values, pairing the
///   n-th input with the n-th value of the window.
///
/// Values are serialised by `append_numeric` in input order.
fn numeric_boundary_inputs(signature: &OpSignature) -> Vec<(String, Vec<u8>)> {
    let values = numeric_values();
    let arity = signature.inputs.len();

    // Number of sliding windows the mixed pass will yield (zero when the
    // signature has fewer than two inputs or more inputs than values).
    let window_count = if arity > 1 {
        values.len().saturating_sub(arity - 1)
    } else {
        0
    };
    let mut probes: Vec<(String, Vec<u8>)> =
        Vec::with_capacity(values.len() + window_count + 1);

    if arity == 0 {
        probes.push(("empty".to_string(), Vec::new()));
        return probes;
    }

    probes.extend(values.iter().enumerate().map(|(idx, &value)| {
        let mut bytes = Vec::with_capacity(signature.min_input_bytes());
        signature
            .inputs
            .iter()
            .for_each(|ty| append_numeric(&mut bytes, ty, value));
        (format!("numeric-all-{idx}"), bytes)
    }));

    if arity > 1 {
        probes.extend(values.windows(arity).enumerate().map(|(idx, window)| {
            let mut bytes = Vec::with_capacity(signature.min_input_bytes());
            for (ty, &value) in signature.inputs.iter().zip(window) {
                append_numeric(&mut bytes, ty, value);
            }
            (format!("numeric-mixed-{idx}"), bytes)
        }));
    }

    probes
}

/// Append the little-endian encoding of `value` for one input type to `out`.
///
/// * `U32`, `I32`, `Bool`, `F32`: the low 32 bits of `value` (4 bytes).
/// * `U64`, `F64`: all 64 bits of `value` (8 bytes).
/// * `Vec2U32`: the low 32 bits followed by their bitwise complement.
/// * `Vec4U32`: the low 32 bits, their complement, the low bits rotated left
///   by 7, and the low bits rotated right by 11.
/// * `F16`, `BF16`: the low 16 bits of `value` (2 bytes).
/// * `Bytes`, `Array`, `Tensor`: nothing is written.
fn append_numeric(out: &mut Vec<u8>, ty: &DataType, value: u64) {
    // Truncation is intentional: narrower types take the low bits of `value`.
    let low32 = value as u32;
    match ty {
        DataType::U32 | DataType::I32 | DataType::Bool | DataType::F32 => {
            out.extend_from_slice(&low32.to_le_bytes());
        }
        DataType::U64 | DataType::F64 => {
            out.extend_from_slice(&value.to_le_bytes());
        }
        DataType::Vec2U32 => {
            for lane in [low32, !low32] {
                out.extend_from_slice(&lane.to_le_bytes());
            }
        }
        DataType::Vec4U32 => {
            let lanes = [
                low32,
                !low32,
                low32.rotate_left(7),
                low32.rotate_right(11),
            ];
            for lane in lanes {
                out.extend_from_slice(&lane.to_le_bytes());
            }
        }
        DataType::F16 | DataType::BF16 => {
            let low16 = value as u16;
            out.extend_from_slice(&low16.to_le_bytes());
        }
        DataType::Bytes | DataType::Array { .. } | DataType::Tensor => {}
    }
}

/// Ordered table of boundary values used to build numeric probes.
///
/// The order is significant: probe labels carry the index into this table
/// and mixed probes use windows of consecutive entries. Entries 4 and 5
/// (`i32::MIN` reinterpreted as unsigned, and `i32::MAX`) have the same
/// numeric value as entries 8 and 9.
fn numeric_values() -> &'static [u64] {
    static BOUNDARY_VALUES: [u64; 12] = [
        0,
        1,
        u32::MAX as u64,
        u32::MAX as u64 - 1,
        i32::MIN as u32 as u64,
        i32::MAX as u64,
        0xAAAA_AAAA,
        0x5555_5555,
        0x8000_0000,
        0x7FFF_FFFF,
        u64::MAX,
        u64::MAX - 1,
    ];
    &BOUNDARY_VALUES
}

/// Buffer lengths used for byte-input boundary probes: 0 through 4, then
/// each power of two from 8 to 256 together with the value just below it.
fn byte_lengths() -> &'static [usize] {
    static LENGTHS: [usize; 17] = [
        0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256,
    ];
    &LENGTHS
}

/// Report whether `ty` is one of the types treated as numeric when choosing
/// probes: the scalar integers, `Bool`, the `u32` vector types, and the
/// floating-point types. Every other type yields `false`.
fn is_numeric(ty: &DataType) -> bool {
    matches!(
        ty,
        // Scalar integers and booleans.
        DataType::U32 | DataType::I32 | DataType::U64 | DataType::Bool
            // Fixed-width vectors of `u32` lanes.
            | DataType::Vec2U32 | DataType::Vec4U32
            // Floating-point types.
            | DataType::F16 | DataType::BF16 | DataType::F32 | DataType::F64
    )
}

/// Small deterministic generator used for the pseudo-random probes.
///
/// Each step updates the 64-bit state as `state * MULTIPLIER + INCREMENT`
/// with wrapping arithmetic and yields the upper 32 bits of the new state.
struct ProbeRng {
    state: u64,
}

impl ProbeRng {
    /// Multiplier applied to the state on every step.
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    /// Constant added to the state on every step.
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    /// Create a generator whose initial state is exactly `seed`.
    fn new(seed: u64) -> Self {
        ProbeRng { state: seed }
    }

    /// Advance the state once and return the high 32 bits of the new state.
    fn next_u32(&mut self) -> u32 {
        let advanced = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.state = advanced;
        (advanced >> 32) as u32
    }
}