use std::fmt::Write;
use crate::schema::Contract;
pub fn generate_invariant_trait(contract: &Contract) -> String {
if contract.type_invariants.is_empty() {
return String::new();
}
let mut out = String::with_capacity(1024);
out.push_str("/// Trait for types with provable invariants (Meyer's class invariants).\n");
out.push_str("pub trait Invariant {\n");
out.push_str(" /// Returns true if all invariants hold for this instance.\n");
out.push_str(" fn is_valid(&self) -> bool;\n");
out.push_str("}\n\n");
let mut by_type: std::collections::BTreeMap<&str, Vec<&crate::schema::TypeInvariant>> =
std::collections::BTreeMap::new();
for inv in &contract.type_invariants {
by_type.entry(&inv.type_name).or_default().push(inv);
}
for (type_name, invariants) in &by_type {
let _ = writeln!(out, "impl Invariant for {type_name} {{");
let _ = writeln!(out, " fn is_valid(&self) -> bool {{");
for (i, inv) in invariants.iter().enumerate() {
if let Some(ref desc) = inv.description {
let _ = writeln!(out, " // {desc}");
}
if i < invariants.len() - 1 {
let _ = writeln!(out, " ({}) &&", inv.predicate);
} else {
let _ = writeln!(out, " ({})", inv.predicate);
}
}
let _ = writeln!(out, " }}");
let _ = writeln!(out, "}}\n");
}
out
}
/// Generates a `#[cfg(kani)]` module holding one Kani proof harness per
/// distinct type named in `contract.type_invariants`, in first-appearance order.
///
/// Each harness creates an arbitrary value of the type with `kani::any()`,
/// assumes `is_valid()`, and then places a `kani::cover!()`. That point is
/// reachable only if some value satisfies the invariant, so Kani's cover report
/// shows whether the invariant is satisfiable. An `assume` on its own never
/// fails, so without the cover the harness would pass vacuously.
///
/// Harness names are `verify_<fragment>_invariant_satisfiable`. The fragment
/// comes from the type name and is sanitized and deduplicated (see
/// `unique_harness_fragment`).
///
/// Returns an empty string when the contract declares no type invariants.
pub fn generate_invariant_harnesses(contract: &Contract) -> String {
    if contract.type_invariants.is_empty() {
        return String::new();
    }
    let mut out = String::with_capacity(1024);
    out.push_str("#[cfg(kani)]\n");
    out.push_str("mod invariant_verification {\n");
    out.push_str(" use super::*;\n\n");
    let mut types_seen: std::collections::BTreeSet<&str> = std::collections::BTreeSet::new();
    // Harness names already emitted. Distinct types can sanitize to the same
    // fragment (e.g. `Foo` and `foo`), which would otherwise produce duplicate `fn`s.
    let mut names_used: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
    for inv in &contract.type_invariants {
        // Only the first invariant of each type produces a harness.
        if !types_seen.insert(&inv.type_name) {
            continue;
        }
        let fn_name = unique_harness_fragment(&inv.type_name, &mut names_used);
        // Writing into a `String` cannot fail, so the `fmt::Result`s are discarded.
        let _ = writeln!(
            out,
            " /// Verify invariant is satisfiable for `{}`.",
            inv.type_name
        );
        let _ = writeln!(out, " #[kani::proof]");
        let _ = writeln!(out, " fn verify_{fn_name}_invariant_satisfiable() {{");
        let _ = writeln!(out, " let v: {} = kani::any();", inv.type_name);
        let _ = writeln!(out, " kani::assume(v.is_valid());");
        let _ = writeln!(
            out,
            " // Kani reports this cover as satisfied only if some value passes is_valid()"
        );
        let _ = writeln!(out, " kani::cover!();");
        let _ = writeln!(out, " }}\n");
    }
    out.push_str("}\n");
    out
}

/// Derives a harness-name fragment from `type_name` that is legal inside a Rust
/// identifier and not already in `used`, records it in `used`, and returns it.
///
/// The name is lowercased and `::` becomes `_`, as before. Any remaining
/// character that cannot appear in an identifier (`<`, `>`, `,`, `[`, `;`,
/// spaces, ...) is replaced with `_`. On a clash, a numeric suffix
/// (`_2`, `_3`, ...) is appended.
fn unique_harness_fragment(
    type_name: &str,
    used: &mut std::collections::BTreeSet<String>,
) -> String {
    let base: String = type_name
        .to_lowercase()
        .replace("::", "_")
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '_' { c } else { '_' })
        .collect();
    let mut candidate = base.clone();
    let mut suffix = 2u32;
    while used.contains(&candidate) {
        candidate = format!("{base}_{suffix}");
        suffix += 1;
    }
    used.insert(candidate.clone());
    candidate
}
/// Builds the complete generated source for a contract's type invariants.
///
/// The output is a fixed two-line auto-generation header, then the `Invariant`
/// trait and its impls, then the Kani harness module. When the contract has
/// no type invariants (the trait code comes back empty), returns an empty
/// string instead.
pub fn generate_invariants(contract: &Contract) -> String {
    let trait_code = generate_invariant_trait(contract);
    // No trait code means nothing to verify; skip building the harnesses.
    if trait_code.is_empty() {
        return String::new();
    }
    let harness_code = generate_invariant_harnesses(contract);
    [
        "// Auto-generated type invariant code from YAML contract.\n",
        "// Regenerate with: pv invariants <contract.yaml>\n\n",
        trait_code.as_str(),
        harness_code.as_str(),
    ]
    .concat()
}
// Unit tests for the invariant code generators. Each test parses an inline YAML
// contract with `parse_contract_str` and checks substrings of the generated code.
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::parse_contract_str;
// A contract with no `type_invariants` section must produce no output at all.
#[test]
fn empty_contract_produces_no_output() {
let contract = parse_contract_str(
r#"
metadata:
version: "1.0.0"
description: "No invariants"
references: ["Test"]
equations:
f:
formula: "f(x) = x"
falsification_tests: []
"#,
)
.unwrap();
assert!(generate_invariants(&contract).is_empty());
}
// Two invariants on one type: expect the trait, one impl containing both
// predicates, and a Kani harness named after the lowercased type name.
#[test]
fn generates_trait_impl() {
let contract = parse_contract_str(
r#"
metadata:
version: "1.0.0"
description: "With invariants"
references: ["Test"]
equations:
f:
formula: "f(x) = x"
type_invariants:
- name: tensor_valid
type: "ValidatedTensor"
predicate: "!self.dims.is_empty()"
description: "At least one dimension"
- name: tensor_size
type: "ValidatedTensor"
predicate: "self.dims.iter().product::<usize>() == self.data.len()"
description: "Data length matches dimensions"
falsification_tests: []
"#,
)
.unwrap();
let output = generate_invariants(&contract);
assert!(output.contains("pub trait Invariant"));
assert!(output.contains("impl Invariant for ValidatedTensor"));
assert!(output.contains("!self.dims.is_empty()"));
assert!(output.contains("self.dims.iter().product"));
assert!(output.contains("#[kani::proof]"));
assert!(output.contains("verify_validatedtensor_invariant_satisfiable"));
}
// Invariants on two different types: each type gets its own impl and its own harness.
#[test]
fn multiple_types_get_separate_impls() {
let contract = parse_contract_str(
r#"
metadata:
version: "1.0.0"
description: "Multi-type"
references: ["Test"]
equations:
f:
formula: "f(x) = x"
type_invariants:
- name: a_valid
type: "TypeA"
predicate: "self.x > 0"
- name: b_valid
type: "TypeB"
predicate: "!self.items.is_empty()"
falsification_tests: []
"#,
)
.unwrap();
let output = generate_invariants(&contract);
assert!(output.contains("impl Invariant for TypeA"));
assert!(output.contains("impl Invariant for TypeB"));
assert!(output.contains("verify_typea_invariant_satisfiable"));
assert!(output.contains("verify_typeb_invariant_satisfiable"));
}
}