vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Per-operation certification: parity + law verification.

use std::collections::BTreeMap;

use crate::enforce::enforcers::float_semantics::{FloatFinding, FloatGateStatus};
use crate::proof::algebra::checker::{verify_laws, verify_laws_witnessed};
use crate::proof::algebra::gpu_checker::verify_gpu_laws_witnessed;
use crate::spec::types::{DataType, OpSpec, ParityFailure};
use vyre_spec::Category;

use super::parity::{run_parity, run_streaming_parity};
use super::{CertificateStrength, OpOutcome, OpResult, MAX_RECORDED_PARITY_FAILURES};

/// Certify a single operation against `backend` at the requested `strength`.
///
/// Stages, in order:
/// 1. Category C ops whose availability list excludes this backend are
///    reported `Unsupported` immediately.
/// 2. Float specs go through the float-semantics gate; any finding ends
///    certification early via `float_gate_result`.
/// 3. Parity runs (streaming parity for `Legendary`, plain parity otherwise);
///    a parity run flagged `unsupported` yields `Unsupported`.
/// 4. Algebraic laws run through `run_laws`, and their results are folded
///    per law name.
///
/// The op is `Passed` only when parity recorded no failures, no law violation
/// was collected, and the count of verified laws equals the count of declared
/// laws. Setting `VYRE_CERTIFY_TRACE` prints a progress line to stderr after
/// parity and after law verification.
pub(super) fn certify_op(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
    strength: CertificateStrength,
) -> OpResult {
    use crate::enforce::enforcers::float_semantics as float_sem;

    // The environment variable is consulted at each stage.
    let trace = |stage: &str| {
        if std::env::var_os("VYRE_CERTIFY_TRACE").is_some() {
            eprintln!("certify_op: {stage} done");
        }
    };

    if category_c_unsupported_by_backend(backend, spec) {
        return unsupported_op_result(spec);
    }

    if float_sem::is_float_spec(spec) {
        let findings = float_sem::enforce_float_semantics(backend, spec);
        if !findings.is_empty() {
            return float_gate_result(spec, findings);
        }
    }

    let streaming = strength == CertificateStrength::Legendary;
    let parity = if streaming {
        run_streaming_parity(backend, spec)
    } else {
        run_parity(backend, spec)
    };
    trace("parity");
    if parity.unsupported {
        return unsupported_op_result(spec);
    }

    let law_results = run_laws(backend, spec, strength);
    trace("laws");

    // Law name -> (no checker reported a violation, total cases across checkers).
    let mut per_law: BTreeMap<String, (bool, u64)> = BTreeMap::new();
    let mut laws_failed = Vec::new();
    let mut law_cases: u64 = 0;

    for law_result in law_results {
        law_cases = law_cases.saturating_add(law_result.cases_tested);
        let (clean, cases) = per_law.entry(law_result.law_name).or_insert((true, 0));
        *cases = cases.saturating_add(law_result.cases_tested);
        if let Some(violation) = law_result.violation {
            *clean = false;
            laws_failed.push(violation);
        }
    }

    // A declared law that no checker exercised counts as a failure on its own.
    laws_failed.extend(
        spec.laws
            .iter()
            .map(|law| law.name())
            .filter(|name| per_law.get(*name).map_or(0, |&(_, cases)| cases) == 0)
            .map(|name| crate::spec::law::LawViolation {
                law: name.to_string(),
                op_id: spec.id.to_string(),
                a: 0,
                b: 0,
                c: 0,
                lhs: 0,
                rhs: 0,
                message: "Fix: law returned zero cases tested.".to_string(),
            }),
    );

    // Verified means: no violation from any checker AND at least one case ran.
    let laws_verified: Vec<_> = per_law
        .into_iter()
        .filter(|(_, (clean, cases))| *clean && *cases > 0)
        .map(|(name, _)| name)
        .collect();

    let parity_clean = parity.failures.is_empty();
    // NOTE(review): this is a length comparison. A law name declared twice in
    // `spec.laws`, or a passing law reported by a checker but not declared,
    // makes it false without adding anything to `laws_failed` — confirm that
    // silent failure is intended.
    let every_declared_law_verified = laws_verified.len() == spec.laws.len();
    let outcome = if parity_clean && laws_failed.is_empty() && every_declared_law_verified {
        OpOutcome::Passed
    } else {
        OpOutcome::Failed
    };

    // Audit L.1.33: persist the witness count declared by `strength`, not
    // `min(law_cases, declared)`. Capping at `law_cases` (zero for law-less
    // ops) made Standard and FastCheck certificates look identical in JSON.
    // Actual evaluation volume is carried by `cases_tested`.
    let witness_count = strength.witness_count();

    OpResult {
        id: spec.id.to_string(),
        archetype: archetype_for_spec(spec),
        outcome,
        parity_passed: parity_clean,
        laws_verified,
        laws_failed,
        parity_failures: parity
            .failures
            .into_iter()
            .take(MAX_RECORDED_PARITY_FAILURES)
            .collect(),
        cases_tested: parity.cases_tested.saturating_add(law_cases),
        witness_count,
    }
}

/// Build the certification result for a float spec that tripped the
/// float-semantics gate.
///
/// The outcome is `Pending` if any finding has status
/// `FloatGateStatus::Pending`, and `Failed` otherwise. Up to
/// `MAX_RECORDED_PARITY_FAILURES` findings are surfaced as synthetic
/// `ParityFailure`s. Each one has generator `float-enforce-<rule>`, an input
/// label of `SPEC.md:<line>`, and empty input/output buffers. No cases or
/// witnesses are counted, and parity is reported as not passed.
fn float_gate_result(spec: &OpSpec, findings: Vec<FloatFinding>) -> OpResult {
    let has_pending = findings
        .iter()
        .any(|f| f.status == FloatGateStatus::Pending);

    let recorded: Vec<_> = findings
        .into_iter()
        .take(MAX_RECORDED_PARITY_FAILURES)
        .map(|f| ParityFailure {
            op_id: spec.id.to_string(),
            generator: format!("float-enforce-{}", f.rule),
            input_label: format!("SPEC.md:{}", f.spec_line),
            input: Vec::new(),
            gpu_output: Vec::new(),
            cpu_output: Vec::new(),
            message: f.message,
            spec_version: spec.version,
            workgroup_size: 1,
        })
        .collect();

    OpResult {
        id: spec.id.to_string(),
        archetype: archetype_for_spec(spec),
        outcome: if has_pending {
            OpOutcome::Pending
        } else {
            OpOutcome::Failed
        },
        parity_passed: false,
        laws_verified: Vec::new(),
        laws_failed: Vec::new(),
        parity_failures: recorded,
        cases_tested: 0,
        witness_count: 0,
    }
}

/// Run every law checker for `spec` and return the raw per-checker results.
///
/// Results are concatenated in this fixed order:
/// 1. `verify_laws` on the CPU reference (skipped for `FastCheck`);
/// 2. `verify_laws_witnessed` on the CPU reference with
///    `strength.witness_count()` witnesses;
/// 3. `verify_gpu_laws_witnessed` on `backend`, with the same witness count,
///    capped at 256 under `FastCheck`.
///
/// An op is treated as binary when its signature has exactly two inputs.
fn run_laws(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
    strength: CertificateStrength,
) -> Vec<crate::proof::algebra::LawResult> {
    let binary = spec.signature.inputs.len() == 2;
    let fast = strength == CertificateStrength::FastCheck;

    let mut collected = Vec::new();
    if !fast {
        collected.extend(verify_laws(spec.id, spec.cpu_fn, &spec.laws, binary));
    }

    let cpu_witnesses = strength.witness_count();
    collected.extend(verify_laws_witnessed(
        spec.id,
        spec.cpu_fn,
        &spec.laws,
        binary,
        cpu_witnesses,
    ));

    let gpu_witnesses = match strength {
        CertificateStrength::Standard | CertificateStrength::Legendary => cpu_witnesses,
        CertificateStrength::FastCheck => cpu_witnesses.min(256),
    };
    collected.extend(verify_gpu_laws_witnessed(backend, spec, gpu_witnesses));

    collected
}

/// Result for an op the backend cannot run.
///
/// The outcome is `Unsupported`, parity is marked not passed, no laws or
/// failures are recorded, and both the case count and the witness count are
/// zero.
fn unsupported_op_result(spec: &OpSpec) -> OpResult {
    let id = spec.id.to_string();
    let archetype = archetype_for_spec(spec);
    OpResult {
        id,
        archetype,
        outcome: OpOutcome::Unsupported,
        parity_passed: false,
        laws_verified: vec![],
        laws_failed: vec![],
        parity_failures: vec![],
        cases_tested: 0,
        witness_count: 0,
    }
}

/// Derive the archetype label for an OpSpec.
///
/// Uses the explicit `archetype` field from spec.toml when it is non-empty.
/// Otherwise it falls back to a heuristic over the signature shape and the
/// op id, for legacy ops that haven't been classified yet:
///
/// - two inputs, the first `U32`, output `U32`:
///   - `binary-comparison` when the id contains `compare` or `select`, or has
///     a segment that is exactly a comparison mnemonic;
///   - else `binary-arithmetic` when the id contains an arithmetic keyword;
///   - else `binary-bitwise`;
/// - one `U32` input, `U32` output: `unary-arithmetic` (`abs`/`negate`),
///   else `unary-logical` (`logical_not`), else `unary-bitwise`;
/// - one `Bytes` input, `Bytes` output: `decode-bytes-to-bytes` (`decode`),
///   else `compression-bytes-to-bytes` (`compress`), else `tokenize-bytes`;
/// - one `Bytes` input, `U32` output: `hash-bytes-to-u32`;
/// - anything else: `unknown`.
///
/// Id segments are the runs of ASCII alphanumerics between separators such
/// as `.`, `_` or `-`. The comparison mnemonics (`eq`, `ne`, `lt`, `le`,
/// `gt`, `ge`, `lte`, `gte`, `neq`) must match a whole segment. Plain
/// substring matching misclassified ids such as `rotate_left`/`shift_left`
/// (contain `le`), `merge`/`average` (contain `ge`) and `funnel_shift`
/// (contains `ne`) as comparisons.
pub(super) fn archetype_for_spec(spec: &OpSpec) -> String {
    // Comparison mnemonics: too short to match safely as substrings.
    const COMPARISON_SEGMENTS: &[&str] = &["eq", "ne", "lt", "le", "gt", "ge", "lte", "gte", "neq"];
    // Arithmetic keywords: matched anywhere in the id.
    const ARITHMETIC_KEYWORDS: &[&str] = &["add", "sub", "mul", "div", "mod", "min", "max", "clamp"];

    // True when some separator-delimited segment of `id` equals one of `wanted`.
    fn has_segment(id: &str, wanted: &[&str]) -> bool {
        id.split(|c: char| !c.is_ascii_alphanumeric())
            .any(|segment| wanted.iter().any(|w| *w == segment))
    }

    if !spec.archetype.is_empty() {
        return spec.archetype.to_string();
    }

    // Heuristic fallback for ops without an explicit archetype.
    let id: &str = &spec.id;
    let inputs = &spec.signature.inputs;
    let output = &spec.signature.output;
    let label = match (inputs.len(), inputs.first(), output) {
        (2, Some(DataType::U32), DataType::U32) => {
            if id.contains("compare")
                || id.contains("select")
                || has_segment(id, COMPARISON_SEGMENTS)
            {
                "binary-comparison"
            } else if ARITHMETIC_KEYWORDS.iter().any(|kw| id.contains(*kw)) {
                "binary-arithmetic"
            } else {
                "binary-bitwise"
            }
        }
        (1, Some(DataType::U32), DataType::U32) => {
            if id.contains("abs") || id.contains("negate") {
                "unary-arithmetic"
            } else if id.contains("logical_not") {
                "unary-logical"
            } else {
                "unary-bitwise"
            }
        }
        (1, Some(DataType::Bytes), DataType::Bytes) => {
            if id.contains("decode") {
                "decode-bytes-to-bytes"
            } else if id.contains("compress") {
                "compression-bytes-to-bytes"
            } else {
                "tokenize-bytes"
            }
        }
        (1, Some(DataType::Bytes), DataType::U32) => "hash-bytes-to-u32",
        _ => "unknown",
    };
    label.to_string()
}

/// True when `spec` is a Category C op whose `backend_availability` does not
/// list `backend.id()`. Ops in any other category are never reported as
/// unsupported by this check.
fn category_c_unsupported_by_backend(backend: &dyn vyre::VyreBackend, spec: &OpSpec) -> bool {
    if let Category::C {
        backend_availability,
        ..
    } = &spec.category
    {
        !backend_availability.available(backend.id())
    } else {
        false
    }
}

/// Report whether an error message says the operation is unsupported by the
/// backend.
///
/// Matches either the exact, case-sensitive token `UnsupportedByBackend`, or
/// the phrase "unsupported by backend" in any ASCII letter case. The
/// lowercased copy is only built when the token check misses.
pub(super) fn is_unsupported_by_backend_message(message: &str) -> bool {
    const TOKEN: &str = "UnsupportedByBackend";
    const PHRASE: &str = "unsupported by backend";
    message.contains(TOKEN) || message.to_ascii_lowercase().contains(PHRASE)
}