use std::collections::BTreeMap;
use crate::enforce::enforcers::float_semantics::{FloatFinding, FloatGateStatus};
use crate::proof::algebra::checker::{verify_laws, verify_laws_witnessed};
use crate::proof::algebra::gpu_checker::verify_gpu_laws_witnessed;
use crate::spec::types::{DataType, OpSpec, ParityFailure};
use vyre_spec::Category;
use super::parity::{run_parity, run_streaming_parity};
use super::{CertificateStrength, OpOutcome, OpResult, MAX_RECORDED_PARITY_FAILURES};
/// Certifies one operation against `backend` at the requested `strength`.
///
/// The pipeline short-circuits in this order:
/// 1. Category C specs whose backend availability rejects `backend.id()` are reported
///    as `Unsupported`.
/// 2. Float specs for which float-semantics enforcement returns any finding are reported
///    through [`float_gate_result`].
/// 3. Parity runs (streaming parity for `Legendary`, regular parity otherwise). If the
///    parity run flags the op as unsupported, the op is reported as `Unsupported`.
/// 4. Algebraic laws run via [`run_laws`]. Every law declared on the spec must have at
///    least one tested case. Declared laws with zero tested cases are recorded as
///    synthetic violations.
///
/// The op is `Passed` only when parity produced no failures, no law violations were
/// collected, and the number of verified laws equals the number of declared laws.
/// When `VYRE_CERTIFY_TRACE` is set, stage markers are printed to stderr.
pub(super) fn certify_op(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
    strength: CertificateStrength,
) -> OpResult {
    if category_c_unsupported_by_backend(backend, spec) {
        return unsupported_op_result(spec);
    }
    if crate::enforce::enforcers::float_semantics::is_float_spec(spec) {
        let findings =
            crate::enforce::enforcers::float_semantics::enforce_float_semantics(backend, spec);
        if !findings.is_empty() {
            return float_gate_result(spec, findings);
        }
    }

    // Checks the trace env var each time a stage finishes, then prints its marker.
    let trace_stage = |stage: &str| {
        if std::env::var_os("VYRE_CERTIFY_TRACE").is_some() {
            eprintln!("certify_op: {} done", stage);
        }
    };

    let parity = match strength {
        CertificateStrength::Legendary => run_streaming_parity(backend, spec),
        CertificateStrength::FastCheck | CertificateStrength::Standard => {
            run_parity(backend, spec)
        }
    };
    trace_stage("parity");
    if parity.unsupported {
        return unsupported_op_result(spec);
    }

    let law_results = run_laws(backend, spec, strength);
    trace_stage("laws");

    // Per law name: (no violation seen, total cases tested). Violations keep result order.
    let mut per_law: BTreeMap<String, (bool, u64)> = BTreeMap::new();
    let mut laws_failed = Vec::new();
    let mut law_cases: u64 = 0;
    for law_result in law_results {
        law_cases = law_cases.saturating_add(law_result.cases_tested);
        let (clean, cases) = per_law.entry(law_result.law_name).or_insert((true, 0));
        *cases = cases.saturating_add(law_result.cases_tested);
        if let Some(violation) = law_result.violation {
            *clean = false;
            laws_failed.push(violation);
        }
    }

    // A declared law that never exercised a single case cannot count as verified.
    laws_failed.extend(spec.laws.iter().filter_map(|law| {
        let name = law.name();
        let tested = per_law.get(name).map_or(0, |&(_, cases)| cases);
        (tested == 0).then(|| crate::spec::law::LawViolation {
            law: name.to_string(),
            op_id: spec.id.to_string(),
            a: 0,
            b: 0,
            c: 0,
            lhs: 0,
            rhs: 0,
            message: "Fix: law returned zero cases tested.".to_string(),
        })
    }));

    let laws_verified: Vec<_> = per_law
        .into_iter()
        .filter(|(_, (clean, cases))| *clean && *cases > 0)
        .map(|(name, _)| name)
        .collect();

    let parity_passed = parity.failures.is_empty();
    let passed =
        parity_passed && laws_failed.is_empty() && laws_verified.len() == spec.laws.len();

    OpResult {
        id: spec.id.to_string(),
        archetype: archetype_for_spec(spec),
        outcome: if passed {
            OpOutcome::Passed
        } else {
            OpOutcome::Failed
        },
        parity_passed,
        laws_verified,
        laws_failed,
        parity_failures: parity
            .failures
            .into_iter()
            .take(MAX_RECORDED_PARITY_FAILURES)
            .collect(),
        cases_tested: parity.cases_tested.saturating_add(law_cases),
        witness_count: strength.witness_count(),
    }
}
/// Builds the result for a float spec whose semantics enforcement produced findings.
///
/// The outcome is `Pending` if any finding has `FloatGateStatus::Pending`, and `Failed`
/// otherwise. At most `MAX_RECORDED_PARITY_FAILURES` findings are kept. Each kept finding
/// becomes a synthetic `ParityFailure`: its generator names the enforcement rule, its
/// input label points at the SPEC.md line, and its data buffers are empty. No parity or
/// law checks are reported, and zero cases and witnesses are counted.
fn float_gate_result(spec: &OpSpec, findings: Vec<FloatFinding>) -> OpResult {
    let has_pending = findings
        .iter()
        .any(|finding| finding.status == FloatGateStatus::Pending);

    let recorded: Vec<_> = findings
        .into_iter()
        .take(MAX_RECORDED_PARITY_FAILURES)
        .map(|finding| ParityFailure {
            op_id: spec.id.to_string(),
            generator: format!("float-enforce-{}", finding.rule),
            input_label: format!("SPEC.md:{}", finding.spec_line),
            input: Vec::new(),
            gpu_output: Vec::new(),
            cpu_output: Vec::new(),
            message: finding.message,
            spec_version: spec.version,
            workgroup_size: 1,
        })
        .collect();

    OpResult {
        id: spec.id.to_string(),
        archetype: archetype_for_spec(spec),
        outcome: if has_pending {
            OpOutcome::Pending
        } else {
            OpOutcome::Failed
        },
        parity_passed: false,
        laws_verified: Vec::new(),
        laws_failed: Vec::new(),
        parity_failures: recorded,
        cases_tested: 0,
        witness_count: 0,
    }
}
/// Runs every algebraic-law checker that applies to `strength` and concatenates the results.
///
/// The order of results is:
/// 1. `verify_laws` results. This checker is skipped for `FastCheck`.
/// 2. CPU witnessed checks, using `strength.witness_count()` witnesses.
/// 3. GPU witnessed checks on `backend`, using the same witness count, capped at 256 for
///    `FastCheck`.
///
/// A spec with exactly two inputs is treated as binary by the CPU checkers.
fn run_laws(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
    strength: CertificateStrength,
) -> Vec<crate::proof::algebra::LawResult> {
    let binary = spec.signature.inputs.len() == 2;

    let mut collected = match strength {
        CertificateStrength::FastCheck => Vec::new(),
        CertificateStrength::Standard | CertificateStrength::Legendary => {
            verify_laws(spec.id, spec.cpu_fn, &spec.laws, binary)
        }
    };

    let witnesses = strength.witness_count();
    collected.extend(verify_laws_witnessed(
        spec.id,
        spec.cpu_fn,
        &spec.laws,
        binary,
        witnesses,
    ));

    // Only FastCheck caps the GPU witness budget.
    let gpu_witnesses = if matches!(strength, CertificateStrength::FastCheck) {
        witnesses.min(256)
    } else {
        witnesses
    };
    collected.extend(verify_gpu_laws_witnessed(backend, spec, gpu_witnesses));

    collected
}
/// Builds the result reported when `spec` cannot be certified on the current backend.
///
/// The outcome is `Unsupported`. Nothing is verified or failed, and zero cases and
/// witnesses are counted.
fn unsupported_op_result(spec: &OpSpec) -> OpResult {
    OpResult {
        outcome: OpOutcome::Unsupported,
        id: spec.id.to_string(),
        archetype: archetype_for_spec(spec),
        parity_passed: false,
        parity_failures: vec![],
        laws_verified: vec![],
        laws_failed: vec![],
        witness_count: 0,
        cases_tested: 0,
    }
}
/// Returns the archetype label for `spec`.
///
/// A non-empty `spec.archetype` is returned verbatim. Otherwise the label is inferred
/// from the signature's input count, first input type and output type, refined by
/// keywords found in `spec.id`. Anything unrecognised maps to `"unknown"`.
///
/// NOTE(review): keywords are matched as plain substrings of the id, so short hints such
/// as `le`, `ne` or `eq` can match inside unrelated words (e.g. an id containing `left`
/// contains `le`). Confirm this is acceptable for the op-id naming scheme.
pub(super) fn archetype_for_spec(spec: &OpSpec) -> String {
    const COMPARISON_HINTS: [&str; 8] = ["compare", "eq", "lt", "gt", "le", "ge", "ne", "select"];
    const ARITHMETIC_HINTS: [&str; 8] = ["add", "sub", "mul", "div", "mod", "min", "max", "clamp"];

    /// True when any hint occurs as a substring of `id`.
    fn mentions_any(id: &str, hints: &[&str]) -> bool {
        hints.iter().any(|&hint| id.contains(hint))
    }

    if !spec.archetype.is_empty() {
        return spec.archetype.to_string();
    }

    let id: &str = &spec.id;
    let inputs = &spec.signature.inputs;
    // Arms sharing a pattern are tried top to bottom; the unguarded one is the fallback.
    let label = match (inputs.len(), inputs.first(), &spec.signature.output) {
        (2, Some(DataType::U32), DataType::U32) if mentions_any(id, &COMPARISON_HINTS) => {
            "binary-comparison"
        }
        (2, Some(DataType::U32), DataType::U32) if mentions_any(id, &ARITHMETIC_HINTS) => {
            "binary-arithmetic"
        }
        (2, Some(DataType::U32), DataType::U32) => "binary-bitwise",
        (1, Some(DataType::U32), DataType::U32) if mentions_any(id, &["abs", "negate"]) => {
            "unary-arithmetic"
        }
        (1, Some(DataType::U32), DataType::U32) if id.contains("logical_not") => {
            "unary-logical"
        }
        (1, Some(DataType::U32), DataType::U32) => "unary-bitwise",
        (1, Some(DataType::Bytes), DataType::Bytes) if id.contains("decode") => {
            "decode-bytes-to-bytes"
        }
        (1, Some(DataType::Bytes), DataType::Bytes) if id.contains("compress") => {
            "compression-bytes-to-bytes"
        }
        (1, Some(DataType::Bytes), DataType::Bytes) => "tokenize-bytes",
        (1, Some(DataType::Bytes), DataType::U32) => "hash-bytes-to-u32",
        _ => "unknown",
    };
    label.to_string()
}
/// Returns `true` only for Category C specs whose `backend_availability` does not report
/// `backend.id()` as available. Specs in every other category always return `false`.
fn category_c_unsupported_by_backend(backend: &dyn vyre::VyreBackend, spec: &OpSpec) -> bool {
    matches!(
        &spec.category,
        Category::C { backend_availability, .. }
            if !backend_availability.available(backend.id())
    )
}
/// Returns `true` when `message` reports that an operation is unsupported by the backend.
///
/// Two spellings are recognised:
/// - the identifier `UnsupportedByBackend`, matched case-sensitively, and
/// - the phrase `unsupported by backend`, matched ASCII case-insensitively.
///
/// The phrase check compares byte windows with `eq_ignore_ascii_case`. It does not
/// lowercase a copy of the whole message, so no allocation happens on any path, including
/// when the identifier already matched. This gives the same results as
/// `to_ascii_lowercase().contains(..)`, because the phrase is pure ASCII.
pub(super) fn is_unsupported_by_backend_message(message: &str) -> bool {
    const IDENT: &str = "UnsupportedByBackend";
    const PHRASE: &[u8] = b"unsupported by backend";
    message.contains(IDENT)
        || message
            .as_bytes()
            .windows(PHRASE.len())
            .any(|window| window.eq_ignore_ascii_case(PHRASE))
}