aprender-contracts 0.31.1

Papers to Math to Contracts in Code — YAML contract parsing, validation, scaffold generation, and Kani harness codegen for provable Rust kernels
Documentation
//! Type invariant code generation.
//!
//! Generates `Invariant` trait implementations and Kani preservation
//! harnesses from the `type_invariants` section of a YAML contract.

use std::fmt::Write;

use crate::schema::Contract;

/// Generate an `Invariant` trait definition + impls from a contract.
///
/// Returns empty string if the contract has no type invariants.
pub fn generate_invariant_trait(contract: &Contract) -> String {
    if contract.type_invariants.is_empty() {
        return String::new();
    }

    let mut out = String::with_capacity(1024);

    out.push_str("/// Trait for types with provable invariants (Meyer's class invariants).\n");
    out.push_str("pub trait Invariant {\n");
    out.push_str("    /// Returns true if all invariants hold for this instance.\n");
    out.push_str("    fn is_valid(&self) -> bool;\n");
    out.push_str("}\n\n");

    // Group invariants by type
    let mut by_type: std::collections::BTreeMap<&str, Vec<&crate::schema::TypeInvariant>> =
        std::collections::BTreeMap::new();
    for inv in &contract.type_invariants {
        by_type.entry(&inv.type_name).or_default().push(inv);
    }

    for (type_name, invariants) in &by_type {
        let _ = writeln!(out, "impl Invariant for {type_name} {{");
        let _ = writeln!(out, "    fn is_valid(&self) -> bool {{");

        for (i, inv) in invariants.iter().enumerate() {
            if let Some(ref desc) = inv.description {
                let _ = writeln!(out, "        // {desc}");
            }
            if i < invariants.len() - 1 {
                let _ = writeln!(out, "        ({}) &&", inv.predicate);
            } else {
                let _ = writeln!(out, "        ({})", inv.predicate);
            }
        }

        let _ = writeln!(out, "    }}");
        let _ = writeln!(out, "}}\n");
    }

    out
}

/// Generate Kani satisfiability harnesses for type invariants.
///
/// All harnesses are placed in a `#[cfg(kani)] mod invariant_verification`.
/// Each distinct `type_name` gets one `#[kani::proof]` harness, emitted in
/// first-appearance order. Each harness:
/// 1. creates an arbitrary instance with `kani::any()`,
/// 2. assumes `is_valid()` holds for it, and
/// 3. reaches a `kani::cover!()`, which Kani reports as satisfied only if some
///    value actually meets the invariant.
///
/// Without step 3 the harness would verify vacuously, even for an
/// unsatisfiable invariant. The generated code relies on each type
/// implementing `kani::Arbitrary` (required by `kani::any`).
///
/// Returns empty string if the contract has no type invariants.
pub fn generate_invariant_harnesses(contract: &Contract) -> String {
    if contract.type_invariants.is_empty() {
        return String::new();
    }

    let mut out = String::with_capacity(1024);

    out.push_str("#[cfg(kani)]\n");
    out.push_str("mod invariant_verification {\n");
    out.push_str("    use super::*;\n\n");

    // Emit one harness per type, the first time the type is seen.
    let mut types_seen: std::collections::BTreeSet<&str> = std::collections::BTreeSet::new();
    for inv in &contract.type_invariants {
        if types_seen.insert(&inv.type_name) {
            let fn_name = harness_ident(&inv.type_name);
            let _ = writeln!(
                out,
                "    /// Verify invariant is satisfiable for `{}`.",
                inv.type_name
            );
            let _ = writeln!(out, "    #[kani::proof]");
            let _ = writeln!(out, "    fn verify_{fn_name}_invariant_satisfiable() {{");
            let _ = writeln!(out, "        let v: {} = kani::any();", inv.type_name);
            let _ = writeln!(out, "        kani::assume(v.is_valid());");
            let _ = writeln!(
                out,
                "        // Reachable only if some value satisfies the invariant;"
            );
            let _ = writeln!(
                out,
                "        // Kani reports this cover as UNSATISFIABLE otherwise."
            );
            let _ = writeln!(out, "        kani::cover!();");
            let _ = writeln!(out, "    }}\n");
        }
    }

    out.push_str("}\n");
    out
}

/// Turn a (possibly path-qualified or generic) type name into a lowercase
/// identifier fragment for a harness name.
///
/// Examples: `ValidatedTensor` -> `validatedtensor`, `a::B` -> `a_b`,
/// `Wrapper<u8>` -> `wrapper_u8_`.
fn harness_ident(type_name: &str) -> String {
    type_name
        .to_lowercase()
        .replace("::", "_")
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '_' { c } else { '_' })
        .collect()
}

/// Generate the full invariant output (trait + harnesses).
///
/// The result is a two-line "auto-generated" banner, then the trait and impls
/// from [`generate_invariant_trait`], then the Kani harnesses from
/// [`generate_invariant_harnesses`]. A contract without type invariants yields
/// an empty string, with no banner.
pub fn generate_invariants(contract: &Contract) -> String {
    let trait_code = generate_invariant_trait(contract);
    if trait_code.is_empty() {
        // Nothing to verify: skip harness generation and the banner entirely.
        return String::new();
    }
    let harness_code = generate_invariant_harnesses(contract);

    [
        "// Auto-generated type invariant code from YAML contract.\n",
        "// Regenerate with: pv invariants <contract.yaml>\n\n",
        trait_code.as_str(),
        harness_code.as_str(),
    ]
    .concat()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::parse_contract_str;

    /// Parse an inline YAML fixture; every fixture here is expected to be valid.
    fn fixture(yaml: &str) -> Contract {
        parse_contract_str(yaml).unwrap()
    }

    /// Assert that every needle occurs somewhere in the generated code.
    fn assert_contains_all(code: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(code.contains(needle), "generated code is missing {needle:?}");
        }
    }

    #[test]
    fn empty_contract_produces_no_output() {
        let contract = fixture(
            r#"
metadata:
  version: "1.0.0"
  description: "No invariants"
  references: ["Test"]
equations:
  f:
    formula: "f(x) = x"
falsification_tests: []
"#,
        );
        let output = generate_invariants(&contract);
        assert!(output.is_empty());
    }

    #[test]
    fn generates_trait_impl() {
        let contract = fixture(
            r#"
metadata:
  version: "1.0.0"
  description: "With invariants"
  references: ["Test"]
equations:
  f:
    formula: "f(x) = x"
type_invariants:
  - name: tensor_valid
    type: "ValidatedTensor"
    predicate: "!self.dims.is_empty()"
    description: "At least one dimension"
  - name: tensor_size
    type: "ValidatedTensor"
    predicate: "self.dims.iter().product::<usize>() == self.data.len()"
    description: "Data length matches dimensions"
falsification_tests: []
"#,
        );

        assert_contains_all(
            &generate_invariants(&contract),
            &[
                "pub trait Invariant",
                "impl Invariant for ValidatedTensor",
                "!self.dims.is_empty()",
                "self.dims.iter().product",
                "#[kani::proof]",
                "verify_validatedtensor_invariant_satisfiable",
            ],
        );
    }

    #[test]
    fn multiple_types_get_separate_impls() {
        let contract = fixture(
            r#"
metadata:
  version: "1.0.0"
  description: "Multi-type"
  references: ["Test"]
equations:
  f:
    formula: "f(x) = x"
type_invariants:
  - name: a_valid
    type: "TypeA"
    predicate: "self.x > 0"
  - name: b_valid
    type: "TypeB"
    predicate: "!self.items.is_empty()"
falsification_tests: []
"#,
        );

        assert_contains_all(
            &generate_invariants(&contract),
            &[
                "impl Invariant for TypeA",
                "impl Invariant for TypeB",
                "verify_typea_invariant_satisfiable",
                "verify_typeb_invariant_satisfiable",
            ],
        );
    }
}