use crate::binding::{BindingRegistry, ImplStatus, KernelBinding};
use crate::schema::{AppliesTo, Contract, ObligationType};
/// Generates the source of a proptest file whose tests are wired to the
/// kernels bound to `contract_file` in `binding`.
///
/// Bindings with status `ImplStatus::NotImplemented` are ignored. Each proof
/// obligation of `contract` becomes one test inside a single `proptest! { .. }`
/// block. Each test is preceded by doc comments carrying the obligation's
/// property, its type and, when present, its formal statement:
///
/// * obligations that apply to SIMD become ignored stubs;
/// * all others are wired to the first remaining binding, or become ignored
///   stubs when there is none.
///
/// Test names come from `obligation_fn_name` and are de-duplicated here. Two
/// obligations whose properties sanitise to the same identifier would
/// otherwise emit two `fn` items with the same name, which does not compile.
/// On a clash the obligation index is appended (`prop_x` -> `prop_x_3`), and
/// the suffix is bumped until the name is free.
///
/// Returns a single comment line when there are neither implemented bindings
/// nor obligations. Returns a header-and-imports-only file when the contract
/// has no obligations.
///
/// NOTE(review): every obligation is wired to the *first* binding even when
/// several are bound — confirm this is intended.
#[allow(clippy::too_many_lines)]
pub fn generate_wired_probar_tests(
    contract: &Contract,
    contract_file: &str,
    binding: &BindingRegistry,
) -> String {
    let bindings: Vec<&KernelBinding> = binding
        .bindings_for(contract_file)
        .into_iter()
        .filter(|b| b.status != ImplStatus::NotImplemented)
        .collect();
    if bindings.is_empty() && contract.proof_obligations.is_empty() {
        return format!(
            "// No wired tests: no implemented bindings \
             for {contract_file}\n"
        );
    }
    let mut out = String::new();
    emit_header(&mut out, contract_file);
    emit_imports(&mut out, &bindings);
    if contract.proof_obligations.is_empty() {
        out.push_str("// No proof obligations in this contract.\n");
        return out;
    }
    // Function names already emitted into this `proptest!` block.
    let mut used_names = std::collections::HashSet::new();
    out.push_str("proptest! {\n");
    for (i, ob) in contract.proof_obligations.iter().enumerate() {
        let fn_name = unique_fn_name(&mut used_names, obligation_fn_name(ob, i), i);
        out.push_str(&format!(
            " /// Obligation: {} ({})\n",
            ob.property, ob.obligation_type
        ));
        if let Some(ref formal) = ob.formal {
            out.push_str(&format!(" /// Formal: {formal}\n"));
        }
        if ob.applies_to == Some(AppliesTo::Simd) {
            emit_ignored_simd_test(&mut out, &fn_name);
            continue;
        }
        match bindings.first() {
            Some(b) => generate_wired_obligation_test(&mut out, ob, &fn_name, b),
            None => emit_ignored_no_binding(&mut out, &fn_name),
        }
    }
    out.push_str("}\n");
    out
}

/// Records `candidate` in `used` and returns it when it is still free.
/// Otherwise it returns `candidate_{index}`, incrementing the suffix until an
/// unused name is found.
fn unique_fn_name(
    used: &mut std::collections::HashSet<String>,
    candidate: String,
    index: usize,
) -> String {
    if used.insert(candidate.clone()) {
        return candidate;
    }
    let mut suffix = index;
    loop {
        let attempt = format!("{candidate}_{suffix}");
        if used.insert(attempt.clone()) {
            return attempt;
        }
        suffix += 1;
    }
}
/// Writes the banner that opens every generated file. The banner holds the
/// contract path, a fingerprint of that path string, the generator command
/// and a do-not-edit notice, and ends with a blank line.
///
/// NOTE(review): the fingerprint comes from `simple_hash`, a 64-bit FNV-1a of
/// the `contract_file` string, yet the banner labels it `sha256:`. It is left
/// unchanged because downstream tooling may match on this prefix; confirm
/// before relabelling.
fn emit_header(out: &mut String, contract_file: &str) {
    let banner = [
        format!("// CONTRACT: {contract_file}\n"),
        format!("// HASH: sha256:{}\n", simple_hash(contract_file)),
        String::from("// Generated by: pv probar --binding\n"),
        String::from("// DO NOT EDIT — regenerate with `pv probar --binding`\n\n"),
    ];
    for line in &banner {
        out.push_str(line);
    }
}
fn emit_imports(out: &mut String, bindings: &[&KernelBinding]) {
out.push_str("use proptest::prelude::*;\n");
let mut imports: Vec<String> = Vec::new();
for b in bindings {
if let Some(ref mp) = b.module_path {
let import = mp.clone();
if !imports.contains(&import) {
imports.push(import);
}
}
}
for imp in &imports {
out.push_str(&format!("use {imp};\n"));
}
out.push('\n');
}
fn emit_ignored_simd_test(out: &mut String, fn_name: &str) {
out.push_str(" #[test]\n");
out.push_str(
" #[ignore = \"SIMD equivalence \
— trueno domain\"]\n",
);
out.push_str(&format!(" fn {fn_name}(\n"));
out.push_str(
" _x in proptest::collection::vec(\
-100.0f32..100.0, 1..32usize)\n",
);
out.push_str(" ) {\n");
out.push_str(
" // SIMD equivalence testing is \
trueno's responsibility\n",
);
out.push_str(" }\n\n");
}
fn emit_ignored_no_binding(out: &mut String, fn_name: &str) {
out.push_str(" #[test]\n");
out.push_str(" #[ignore = \"no binding available\"]\n");
out.push_str(&format!(" fn {fn_name}(\n"));
out.push_str(
" _x in proptest::collection::vec(\
-100.0f32..100.0, 1..32usize)\n",
);
out.push_str(" ) {\n");
out.push_str(
" // No binding — implement and \
update binding.yaml\n",
);
out.push_str(" }\n\n");
}
/// Emits one wired proptest function named `fn_name` for obligation `ob`,
/// calling the kernel described by `binding`.
///
/// The tolerance is `ob.tolerance` rendered in `{:e}` notation, or `1e-6` when
/// absent. It is later spliced into generated code as `<tol>f64`. This
/// function opens the test (`#[test]`, `fn name(`) and closes its body; the
/// per-obligation-type helper writes the parameter list and body in between.
fn generate_wired_obligation_test(
    out: &mut String,
    ob: &crate::schema::ProofObligation,
    fn_name: &str,
    binding: &KernelBinding,
) {
    let tol = match ob.tolerance {
        Some(t) => format!("{t:e}"),
        None => String::from("1e-6"),
    };
    let kernel = classify_kernel(binding);
    out.push_str(&format!(" #[test]\n fn {fn_name}(\n"));
    match ob.obligation_type {
        ObligationType::Invariant => generate_wired_invariant(out, ob, &tol, &kernel),
        ObligationType::Bound => generate_wired_bound(out, ob, &kernel),
        ObligationType::Monotonicity => generate_wired_monotonicity(out, ob, &kernel),
        ObligationType::Idempotency => generate_wired_idempotency(out, ob, &tol, &kernel),
        ObligationType::Equivalence => generate_wired_equivalence(out, ob, &tol),
        // Not wired to a concrete check yet: emit a placeholder body.
        ObligationType::Linearity
        | ObligationType::Symmetry
        | ObligationType::Associativity
        | ObligationType::Conservation
        | ObligationType::Ordering
        | ObligationType::Completeness
        | ObligationType::Soundness
        | ObligationType::Involution
        | ObligationType::Determinism
        | ObligationType::Roundtrip
        | ObligationType::StateMachine
        | ObligationType::Classification
        | ObligationType::Independence
        | ObligationType::Termination
        | ObligationType::Precondition
        | ObligationType::Postcondition
        | ObligationType::Frame
        | ObligationType::LoopInvariant
        | ObligationType::LoopVariant
        | ObligationType::OldState
        | ObligationType::Subcontract
        | ObligationType::Safety
        | ObligationType::Liveness => generate_wired_generic(out, ob),
    }
    out.push_str(" }\n\n");
}
/// How generated test code invokes a bound kernel, as decided by
/// `classify_kernel` from the binding's `function` string.
#[derive(Debug)]
enum KernelType {
    /// Called as a free function; holds the name exactly as written in the binding.
    FreeFunction(String),
    /// The binding names a `::forward` method; no constructor is wired yet.
    StructMethod,
    /// A `Tensor::<method>` kernel; holds the method name without the `Tensor::` prefix.
    TensorMethod(String),
}
/// Decides how generated code should call the kernel named by
/// `binding.function`, using `"unknown"` when the binding names no function.
///
/// Any name containing `::forward` is a struct method. Otherwise a `Tensor::`
/// prefix marks a tensor method, with the prefix stripped. Everything else is
/// called as a free function by its full name.
///
/// NOTE(review): the `"unknown"` fallback becomes a call to a function named
/// `unknown` in generated code, which presumably does not exist — confirm.
fn classify_kernel(binding: &KernelBinding) -> KernelType {
    let func = match binding.function.as_deref() {
        Some(name) => name,
        None => "unknown",
    };
    if func.contains("::forward") {
        return KernelType::StructMethod;
    }
    match func.strip_prefix("Tensor::") {
        Some(method) => KernelType::TensorMethod(method.to_owned()),
        None => KernelType::FreeFunction(func.to_owned()),
    }
}
/// Emits three generated statements:
///
/// * one that builds a `[1, n]` tensor named `input_var` from the proptest
///   `data` vector;
/// * one that binds `output_var` to the kernel applied to it;
/// * a preceding `let n = data.len();`.
///
/// The call is `name(&input_var)`, except `softmax`, which receives an extra
/// `-1` argument (presumably the axis).
fn emit_free_fn_call(out: &mut String, name: &str, input_var: &str, output_var: &str) {
    let call = if name == "softmax" {
        format!("softmax(&{input_var}, -1)")
    } else {
        format!("{name}(&{input_var})")
    };
    out.push_str(" let n = data.len();\n");
    out.push_str(&format!(
        " let {input_var} = aprender::autograd::Tensor::new(&data, &[1, n]);\n"
    ));
    out.push_str(&format!(" let {output_var} = {call};\n"));
}
/// Emits the input setup and kernel call for kernels that are not free
/// functions, binding the result to `output_var`.
///
/// The generated code first builds a `[1, n]` tensor named `input_var` from
/// the proptest `data` vector, then:
///
/// * `StructMethod`: no constructor is wired yet. It emits a `TODO` comment
///   and binds `output_var` to a clone of the input.
/// * `TensorMethod(m)`: it binds `output_var` to `input_var.m(&input_var)`.
///
/// NOTE(review): the tensor-method form passes the input as its own argument,
/// which presumably only suits binary methods — confirm against the Tensor API.
///
/// # Panics
///
/// Panics via `unreachable!` on `KernelType::FreeFunction`. Callers route those
/// through `emit_free_fn_call` instead.
fn emit_fallback_call(out: &mut String, input_var: &str, output_var: &str, kernel: &KernelType) {
    out.push_str(" let n = data.len();\n");
    out.push_str(&format!(
        " let {input_var} = aprender::autograd::Tensor::new(&data, &[1, n]);\n"
    ));
    let call = match kernel {
        KernelType::FreeFunction(_) => unreachable!(),
        KernelType::StructMethod => {
            let todo = format!(" // TODO: wire up struct constructor + .forward(&{input_var})\n");
            out.push_str(&todo);
            format!(" let {output_var} = {input_var}.clone();\n")
        }
        KernelType::TensorMethod(method) => {
            format!(" let {output_var} = {input_var}.{method}(&{input_var});\n")
        }
    };
    out.push_str(&call);
}
/// Emits the proptest parameters and body for an invariant obligation.
///
/// The kernel runs on `x` (64 or fewer values in `[-100, 100)`), producing
/// `y`. The generated check only asserts that every output value is finite.
/// The obligation's property is recorded as a comment, and the tolerance is
/// emitted as an unused `let _` binding.
fn generate_wired_invariant(
    out: &mut String,
    ob: &crate::schema::ProofObligation,
    tol: &str,
    kernel: &KernelType,
) {
    out.push_str(" data in proptest::collection::vec(-100.0f32..100.0, 1..64usize)\n");
    out.push_str(" ) {\n");
    match kernel {
        KernelType::FreeFunction(name) => emit_free_fn_call(out, name, "x", "y"),
        other => emit_fallback_call(out, "x", "y", other),
    }
    let checks = [
        format!(" // Invariant: {}\n", ob.property),
        format!(" let _ = {tol}f64; // tolerance\n"),
        String::from(" let y_data = y.data();\n"),
        String::from(" for &val in y_data {\n"),
        String::from(" prop_assert!(val.is_finite(), \"output not finite: {}\", val);\n"),
        String::from(" }\n"),
    ];
    out.extend(checks.iter().map(String::as_str));
}
/// Emits the proptest parameters and body for a bound obligation.
///
/// The kernel runs on `x` (64 or fewer values in `[-100, 100)`), producing
/// `y`. The generated body only asserts that every output is finite. The
/// obligation's actual bound appears solely as a comment.
fn generate_wired_bound(
    out: &mut String,
    ob: &crate::schema::ProofObligation,
    kernel: &KernelType,
) {
    out.push_str(" data in proptest::collection::vec(-100.0f32..100.0, 1..64usize)\n");
    out.push_str(" ) {\n");
    match kernel {
        KernelType::FreeFunction(name) => emit_free_fn_call(out, name, "x", "y"),
        other => emit_fallback_call(out, "x", "y", other),
    }
    let checks = [
        format!(" // Bound: {}\n", ob.property),
        String::from(" let y_data = y.data();\n"),
        String::from(" for &val in y_data {\n"),
        String::from(" prop_assert!(val.is_finite(), \"output not finite\");\n"),
        String::from(" }\n"),
    ];
    out.extend(checks.iter().map(String::as_str));
}
/// Emits the proptest parameters and body for a monotonicity obligation.
///
/// Inputs hold at least two values (`2..64`), so there is always a pair to
/// compare. The generated check walks every ordered index pair `(i, j)` and
/// asserts that `x[i] > x[j]` implies `y[i] >= y[j]` (non-decreasing). On
/// failure it reports both pairs. The pairwise scan is quadratic in the input
/// length, which is at most 63 here.
fn generate_wired_monotonicity(
    out: &mut String,
    ob: &crate::schema::ProofObligation,
    kernel: &KernelType,
) {
    out.push_str(" data in proptest::collection::vec(-100.0f32..100.0, 2..64usize)\n");
    out.push_str(" ) {\n");
    match kernel {
        KernelType::FreeFunction(name) => emit_free_fn_call(out, name, "x", "y"),
        other => emit_fallback_call(out, "x", "y", other),
    }
    let heading = format!(" // Monotonicity: {}\n", ob.property);
    out.push_str(&heading);
    out.push_str(" let x_data = x.data();\n let y_data = y.data();\n");
    // Pairwise check; the `\x20` escapes keep a leading space on lines that
    // follow a string continuation.
    out.push_str(
        " for i in 0..x_data.len() {\n\
         \x20 for j in 0..x_data.len() {\n\
         \x20 if x_data[i] > x_data[j] {\n\
         \x20 prop_assert!(\
         y_data[i] >= y_data[j],\n\
         \x20 \"monotonicity \
         violated: x[{}]={} > x[{}]={} but \
         y[{}]={} < y[{}]={}\",\n\
         \x20 i, x_data[i], j, \
         x_data[j], i, y_data[i], j, y_data[j]);\n\
         \x20 }\n\
         \x20 }\n\
         \x20 }\n",
    );
}
fn generate_wired_idempotency(
out: &mut String,
ob: &crate::schema::ProofObligation,
tol: &str,
kernel: &KernelType,
) {
out.push_str(
" data in proptest::collection::vec(\
-10.0f32..10.0, 1..32usize)\n",
);
out.push_str(" ) {\n");
if let KernelType::FreeFunction(name) = kernel {
out.push_str(" let n = data.len();\n");
out.push_str(
" let x = aprender::autograd::Tensor\
::new(&data, &[1, n]);\n",
);
if name == "softmax" {
out.push_str(" let once = softmax(&x, -1);\n");
out.push_str(" let twice = softmax(&once, -1);\n");
} else {
out.push_str(&format!(" let once = {name}(&x);\n"));
out.push_str(&format!(" let twice = {name}(&once);\n"));
}
} else {
out.push_str(" let n = data.len();\n");
out.push_str(
" let x = aprender::autograd::Tensor\
::new(&data, &[1, n]);\n",
);
out.push_str(" let once = x.clone(); // TODO\n");
out.push_str(" let twice = once.clone(); // TODO\n");
}
out.push_str(&format!(" // Idempotency: {}\n", ob.property));
out.push_str(
" let once_data = once.data();\n\
\x20 let twice_data = twice.data();\n\
\x20 for (a, b) in once_data.iter()\
.zip(twice_data.iter()) {\n",
);
out.push_str(&format!(
" prop_assert!(({tol}f64 > \
(f64::from(*a) - f64::from(*b)).abs()),\n\
\x20 \"idempotency violated: \
f(x)={{}} but f(f(x))={{}}\", a, b);\n"
));
out.push_str(" }\n");
}
/// Emits the proptest parameters and a placeholder body for an equivalence
/// obligation.
///
/// No assertion is generated: the body records the property and the
/// tolerance (as an unused `let _`), touches `data`, and explains that
/// equivalence needs two implementations to compare.
fn generate_wired_equivalence(out: &mut String, ob: &crate::schema::ProofObligation, tol: &str) {
    let body = [
        String::from(" data in proptest::collection::vec(-100.0f32..100.0, 1..64usize)\n"),
        String::from(" ) {\n"),
        format!(" // Equivalence: {}\n", ob.property),
        format!(" let _ = {tol}f64; // ULP tolerance\n"),
        String::from(" let _ = &data;\n"),
        String::from("  // Equivalence tests require two implementations to compare.\n"),
        String::from("  // For SIMD vs scalar, this is trueno's responsibility.\n"),
    ];
    out.extend(body.iter().map(String::as_str));
}
/// Emits the proptest parameters and a `TODO` placeholder body for obligation
/// types that have no wired check yet. The body records the obligation type
/// and property as a comment and touches `data`.
fn generate_wired_generic(out: &mut String, ob: &crate::schema::ProofObligation) {
    let heading = format!(" // {}: {}\n", ob.obligation_type, ob.property);
    for piece in [
        " data in proptest::collection::vec(-100.0f32..100.0, 1..64usize)\n",
        " ) {\n",
        heading.as_str(),
        " let _ = &data;\n",
        "  // TODO: wire up obligation test\n",
    ] {
        out.push_str(piece);
    }
}
/// Builds the proptest function name for obligation number `index` from its
/// property text. The text is lower-cased, every run of non-alphanumeric
/// characters becomes a single `_`, leading and trailing separators are
/// dropped, and the result gets a `prop_` prefix. Falls back to
/// `prop_obligation_{index}` when no alphanumeric character remains.
///
/// Collapsing separator runs matters because rustc's `non_snake_case` lint
/// rejects identifiers containing `__`. Without it, a property such as
/// `sum(x) = 1` would yield the warning-triggering `prop_sum_x____1`; it now
/// yields `prop_sum_x_1`.
///
/// NOTE(review): `char::is_alphanumeric` also accepts Unicode numerics such as
/// `²`, which are not valid in Rust identifiers — confirm contract
/// properties never contain them.
fn obligation_fn_name(ob: &crate::schema::ProofObligation, index: usize) -> String {
    let mut base = String::new();
    for c in ob.property.to_lowercase().chars() {
        if c.is_alphanumeric() {
            base.push(c);
        } else if !base.is_empty() && !base.ends_with('_') {
            // Skip leading separators and collapse runs into one `_`.
            base.push('_');
        }
    }
    // At most one trailing separator can remain; drop it.
    if base.ends_with('_') {
        base.pop();
    }
    if base.is_empty() {
        format!("prop_obligation_{index}")
    } else {
        format!("prop_{base}")
    }
}
/// Returns the 64-bit FNV-1a hash of `input`'s UTF-8 bytes as 16 lowercase,
/// zero-padded hex digits. This is not a cryptographic hash.
fn simple_hash(input: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0100_0000_01b3;
    let digest = input.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{digest:016x}")
}
// Test module. Its contents live in sibling files that `include!` splices in
// verbatim; the paths resolve relative to this source file.
#[cfg(test)]
mod tests {
    include!("wired_tests1.rs");
    include!("wired_tests2.rs");
}