Skip to main content

provable_contracts/
codegen.rs

//! Code generation from YAML contracts → Rust `debug_assert!()` checks.
//!
//! Reads contract YAML files and generates a Rust module of assertion
//! macros (`macro_rules!`) that can be called from production code. Zero cost
//! in release builds.
//!
//! Also emits Lean 4 stub comments linking each equation to its Lean theorem.
7
8use crate::schema::Contract;
9use std::path::Path;
10
/// Output of code generation for a single contract YAML file.
///
/// Carries the emitted Rust macro source, the Lean linkage comments, and
/// per-kind counts of what was emitted.
#[derive(Clone, Debug)]
pub struct GeneratedContract {
    /// Contract name, taken from the YAML file stem.
    pub name: String,
    /// Rust source of the generated `macro_rules!` assertion macros.
    pub rust_assertions: String,
    /// Lean 4 comment stubs linking equations to their Lean theorems.
    pub lean_stubs: String,
    /// How many precondition assertions were emitted.
    pub precondition_count: usize,
    /// How many postcondition assertions were emitted.
    pub postcondition_count: usize,
    /// How many equations reference a Lean theorem.
    pub lean_theorem_count: usize,
    /// How many invariant assertions were emitted.
    pub invariant_count: usize,
}
29
30/// Generate Rust assertion code from a contract's equations.
31///
32/// For each equation with `preconditions` or `postconditions`, generates:
33/// ```rust,ignore
34/// pub fn check_gemv_preconditions(a_len: usize, rows: usize, cols: usize) {
35///     debug_assert!(a_len == rows * cols, "Pre: a.len() == rows * cols");
36/// }
37/// ```
38pub fn generate_from_contract(name: &str, contract: &Contract) -> GeneratedContract {
39    let mut rust = String::new();
40    let mut lean = String::new();
41    let mut pre_count = 0;
42    let mut post_count = 0;
43    let mut invariant_count = 0;
44    let mut lean_count = 0;
45
46    rust.push_str(&format!(
47        "// Auto-generated from contracts/{name}.yaml — DO NOT EDIT\n"
48    ));
49    rust.push_str(&format!("// Contract: {name}\n\n"));
50
51    for (eq_name, equation) in &contract.equations {
52        let macro_name = eq_name.replace('-', "_").to_lowercase();
53        pre_count +=
54            emit_precondition_macro(&mut rust, eq_name, &macro_name, &equation.preconditions);
55        post_count +=
56            emit_postcondition_macro(&mut rust, eq_name, &macro_name, &equation.postconditions);
57        invariant_count +=
58            emit_invariant_macro(&mut rust, eq_name, &macro_name, &equation.invariants);
59        emit_combined_macro(
60            &mut rust,
61            eq_name,
62            &macro_name,
63            &equation.preconditions,
64            &equation.postconditions,
65        );
66
67        // Lean theorem linkage
68        if let Some(ref theorem) = equation.lean_theorem {
69            lean.push_str(&format!("-- Equation: {eq_name}\n"));
70            lean.push_str(&format!("-- Lean theorem: {theorem}\n"));
71            lean.push_str(&format!(
72                "-- Formula: {}\n\n",
73                equation.formula.lines().next().unwrap_or("")
74            ));
75            lean_count += 1;
76        }
77    }
78
79    GeneratedContract {
80        name: name.to_string(),
81        rust_assertions: rust,
82        lean_stubs: lean,
83        precondition_count: pre_count,
84        postcondition_count: post_count,
85        lean_theorem_count: lean_count,
86        invariant_count,
87    }
88}
89
90/// Emit a precondition macro for an equation. Returns number of assertions emitted.
91fn emit_precondition_macro(
92    rust: &mut String,
93    eq_name: &str,
94    macro_name: &str,
95    pres: &[String],
96) -> usize {
97    if pres.is_empty() {
98        return 0;
99    }
100    let uses_domain = pres.iter().any(|p| {
101        p.contains("==")
102            || p.contains("eps")
103            || p.contains("weight")
104            || p.contains("freqs")
105            || p.contains("scale")
106            || p.contains('.') && !p.contains("is_empty")
107    });
108    let mut count = 0;
109    rust.push_str(&format!("/// Preconditions for equation `{eq_name}`.\n"));
110    if uses_domain {
111        let pv = detect_primary_var(pres);
112        // Use _pv_ prefix to avoid macro hygiene issues in cross-crate builds
113        let safe_pv = format!("_pv_{pv}");
114        rust.push_str(&format!(
115            "/// Domain-specific. Call: `contract_pre_{macro_name}!(slice_expr)`\n"
116        ));
117        rust.push_str(&format!("macro_rules! contract_pre_{macro_name} {{\n"));
118        // Zero-arg form: no-op (proc-macro compatibility)
119        rust.push_str("    () => {{}};\n");
120        rust.push_str("    ($input:expr) => {{\n");
121        rust.push_str(&format!("        let {safe_pv} = &$input;\n"));
122        for pre in pres {
123            if has_unbound_vars(pre, &pv) {
124                continue;
125            }
126            let mapped = pre.replace(&pv, &safe_pv);
127            let esc = pre.replace('"', "\\\"");
128            rust.push_str(&format!("        debug_assert!({mapped},\n            \"Contract {eq_name}: precondition violated — {esc}\");\n"));
129            count += 1;
130        }
131        rust.push_str("    }};\n}\n\n");
132    } else {
133        rust.push_str(&format!(
134            "/// Call at function entry: `contract_pre_{macro_name}!(input_expr)`\n"
135        ));
136        rust.push_str(&format!("macro_rules! contract_pre_{macro_name} {{\n"));
137        rust.push_str("    () => {{}};\n");
138        rust.push_str("    ($input:expr) => {{\n        let _contract_input = &$input;\n");
139        for pre in pres {
140            // Map common variable names to _contract_input
141            let mut assertion = pre
142                .replace("input", "_contract_input")
143                .replace("x.", "_contract_input.")
144                .replace("x)", "_contract_input)");
145            // Handle !var.method() patterns — map leading var to _contract_input
146            // Only for safe methods: is_empty, len, is_finite, iter (type-polymorphic)
147            if has_unbound_vars(&assertion, "_contract_input") {
148                let stripped = pre.trim_start_matches('!');
149                if let Some(dot) = stripped.find('.') {
150                    let var = &stripped[..dot];
151                    let method = &stripped[dot + 1..];
152                    // Only map for methods that exist on slices/vecs (not is_empty which fails on scalars)
153                    let safe_method = method.starts_with("len()")
154                        || method.starts_with("iter()")
155                        || method.starts_with("is_finite()");
156                    if safe_method
157                        && !var.is_empty()
158                        && var.chars().all(|c| c.is_alphanumeric() || c == '_')
159                    {
160                        let mapped = pre.replace(var, "_contract_input");
161                        if !has_unbound_vars(&mapped, "_contract_input") {
162                            assertion = mapped;
163                        }
164                    }
165                }
166            }
167            // Skip assertions that still have unbound variables after substitution
168            if has_unbound_vars(&assertion, "_contract_input") {
169                continue;
170            }
171            let esc = pre.replace('"', "\\\"");
172            rust.push_str(&format!("        debug_assert!({assertion},\n            \"Contract {eq_name}: precondition violated — {esc}\");\n"));
173            count += 1;
174        }
175        rust.push_str("    }};\n}\n\n");
176    }
177    count
178}
179
180/// Emit a postcondition macro for an equation. Returns number of assertions emitted.
181fn emit_postcondition_macro(
182    rust: &mut String,
183    eq_name: &str,
184    macro_name: &str,
185    posts: &[String],
186) -> usize {
187    if posts.is_empty() {
188        return 0;
189    }
190    let mut count = 0;
191    rust.push_str(&format!("/// Postconditions for equation `{eq_name}`.\n"));
192    rust.push_str(&format!(
193        "/// Call before return: `contract_post_{macro_name}!(result_expr)`\n"
194    ));
195    rust.push_str(&format!("macro_rules! contract_post_{macro_name} {{\n"));
196    rust.push_str("    ($result:expr) => {{\n        let _contract_result = &$result;\n");
197    for post in posts {
198        // Replace result with *_contract_result for scalar comparisons (>= 0.0, etc.)
199        // and _contract_result for method calls (.is_finite(), .iter(), .len())
200        let fixed = if post.contains("result.") || post.contains("result)") {
201            post.replace("result", "_contract_result")
202        } else {
203            // Scalar comparison: result >= 0.0 → *_contract_result >= 0.0
204            post.replace("result", "*_contract_result")
205        };
206        // Skip postconditions that reference unbound variables (same hygiene fix as preconditions)
207        if has_unbound_vars(&fixed, "_contract_result") {
208            continue;
209        }
210        let esc = post.replace('"', "\\\"");
211        rust.push_str(&format!("        debug_assert!({fixed}, \"Contract {eq_name}: postcondition violated — {esc}\");\n"));
212        count += 1;
213    }
214    rust.push_str("    }};\n}\n\n");
215    count
216}
217
218/// Emit an invariant macro for an equation. Returns number of assertions emitted.
219/// Invariants are checked as postconditions (after computation completes).
220fn emit_invariant_macro(
221    rust: &mut String,
222    eq_name: &str,
223    macro_name: &str,
224    invariants: &[String],
225) -> usize {
226    if invariants.is_empty() {
227        return 0;
228    }
229    let mut count = 0;
230    rust.push_str(&format!("/// Invariants for equation `{eq_name}`.\n"));
231    rust.push_str(&format!(
232        "/// Check after computation: `contract_inv_{macro_name}!(result_expr)`\n"
233    ));
234    rust.push_str(&format!("macro_rules! contract_inv_{macro_name} {{\n"));
235    rust.push_str("    () => {{}};\n");
236    rust.push_str("    ($result:expr) => {{\n        let _contract_result = &$result;\n");
237    for inv in invariants {
238        // Try to make the invariant into a compilable assertion
239        let fixed = if inv.contains("result.") || inv.contains("result)") {
240            inv.replace("result", "_contract_result")
241        } else if inv.contains(">=")
242            || inv.contains("<=")
243            || inv.contains("==")
244            || inv.contains("> ")
245            || inv.contains("< ")
246        {
247            inv.replace("result", "*_contract_result")
248        } else {
249            continue; // Skip prose invariants that aren't Rust expressions
250        };
251        // Skip invariants with unbound variables
252        if has_unbound_vars(&fixed, "_contract_result") {
253            continue;
254        }
255        let esc = inv.replace('"', "\\\"");
256        rust.push_str(&format!("        debug_assert!({fixed}, \"Contract {eq_name}: invariant violated \u{2014} {esc}\");\n"));
257        count += 1;
258    }
259    rust.push_str("    }};\n}\n\n");
260    count
261}
262
/// Emit `contract_<macro_name>!(input, body)`. The wrapper runs the
/// precondition macro on `input`, evaluates `body`, runs the postcondition
/// macro on that value, and then yields it.
///
/// Emitted only when the equation has both preconditions and postconditions,
/// because the wrapper invokes both `contract_pre_*` and `contract_post_*`.
fn emit_combined_macro(
    rust: &mut String,
    eq_name: &str,
    macro_name: &str,
    pres: &[String],
    posts: &[String],
) {
    let has_both_sides = !pres.is_empty() && !posts.is_empty();
    if has_both_sides {
        let pieces = [
            format!("/// Combined pre+post contract for equation `{eq_name}`.\n"),
            format!("macro_rules! contract_{macro_name} {{\n"),
            String::from("    ($input:expr, $body:expr) => {{\n"),
            format!("        contract_pre_{macro_name}!($input);\n"),
            String::from("        let _contract_result = $body;\n"),
            format!("        contract_post_{macro_name}!(_contract_result);\n"),
            String::from("        _contract_result\n"),
            String::from("    }};\n}\n\n"),
        ];
        rust.push_str(&pieces.concat());
    }
}
287
/// Pick the variable that preconditions are mostly about.
///
/// Returns the text before the first `.` of the first precondition whose
/// prefix is a plain identifier other than `result` (so `logits.len() > 0`
/// yields `logits`). Falls back to `"x"` when no precondition qualifies.
fn detect_primary_var(preconditions: &[String]) -> String {
    let is_plain_ident =
        |s: &str| !s.is_empty() && s.chars().all(|c| c.is_alphanumeric() || c == '_');

    preconditions
        .iter()
        .filter_map(|pre| pre.split_once('.').map(|(receiver, _)| receiver))
        .find(|&receiver| is_plain_ident(receiver) && receiver != "result")
        .unwrap_or("x")
        .to_string()
}
306
/// Report whether `expr` mentions a variable that the generated macro cannot
/// bind.
///
/// `expr` is split crudely on operator and delimiter characters. Empty tokens
/// and tokens starting with a digit are ignored. A token counts as bound when
/// it is one of the following:
///
/// - `primary_var`,
/// - one of a fixed set of literals, primitive type names, closure parameters
///   (`v`, `id`) or helper method names,
/// - any name starting with `is_`.
///
/// Any other token made only of alphanumerics/`_` and at most 20 bytes long
/// is treated as unbound. Tokens containing other characters (such as
/// `f32::EPSILON`) are ignored.
fn has_unbound_vars(expr: &str, primary_var: &str) -> bool {
    const ALWAYS_BOUND: [&str; 16] = [
        "_contract_input",
        "true",
        "false",
        "f32",
        "f64",
        "usize",
        "i32",
        "i64",
        "v",
        "id",
        "iter",
        "all",
        "any",
        "len",
        "abs",
        "sum",
    ];
    const DELIMITERS: &str = "().&|!<>=+- */%,;{}[]";

    let is_bound = |token: &str| {
        token == primary_var || token.starts_with("is_") || ALWAYS_BOUND.contains(&token)
    };
    let looks_like_variable = |token: &str| {
        token.len() <= 20 && token.chars().all(|c| c.is_alphanumeric() || c == '_')
    };

    expr.split(|c: char| DELIMITERS.contains(c))
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .filter(|token| !token.starts_with(|c: char| c.is_ascii_digit()))
        .any(|token| !is_bound(token) && looks_like_variable(token))
}
350
351/// Generate code for all contracts in a directory (recursive).
352pub fn generate_all(contract_dir: &Path) -> Vec<GeneratedContract> {
353    let mut yaml_paths = Vec::new();
354    collect_yaml_files(contract_dir, &mut yaml_paths);
355
356    let mut results = Vec::new();
357    for path in &yaml_paths {
358        let stem = path
359            .file_stem()
360            .and_then(|s| s.to_str())
361            .unwrap_or("unknown")
362            .to_string();
363
364        if let Ok(contract) = crate::schema::parse_contract(path) {
365            let generated = generate_from_contract(&stem, &contract);
366            if generated.precondition_count > 0
367                || generated.postcondition_count > 0
368                || generated.lean_theorem_count > 0
369            {
370                results.push(generated);
371            }
372        }
373    }
374
375    results.sort_by(|a, b| a.name.cmp(&b.name));
376    results
377}
378
/// Walk `dir` recursively and append every `.yaml` contract file to `out`.
///
/// Directories that cannot be read are silently ignored. Subdirectories named
/// `kaizen`, `legacy` or `pipelines` are not descended into, and files named
/// `binding.yaml` are excluded. Paths are appended in `read_dir` order.
fn collect_yaml_files(dir: &Path, out: &mut Vec<std::path::PathBuf>) {
    const SKIPPED_DIRS: [&str; 3] = ["kaizen", "legacy", "pipelines"];

    let entries = match std::fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return,
    };
    for path in entries.filter_map(Result::ok).map(|entry| entry.path()) {
        if path.is_dir() {
            let dir_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
            if !SKIPPED_DIRS.contains(&dir_name) {
                collect_yaml_files(&path, out);
            }
            continue;
        }
        let is_yaml = path.extension().and_then(|e| e.to_str()) == Some("yaml");
        let is_binding = path.file_name().and_then(|n| n.to_str()) == Some("binding.yaml");
        if is_yaml && !is_binding {
            out.push(path);
        }
    }
}
399
400/// Write generated Rust code to a file.
401pub fn write_rust_module(contracts: &[GeneratedContract], output: &Path) -> std::io::Result<()> {
402    let mut content = String::new();
403    content.push_str("// Auto-generated contract assertions from YAML — DO NOT EDIT.\n");
404    content.push_str("// Zero cost in release builds (debug_assert!).\n");
405    content.push_str("// Regenerate: pv codegen contracts/ -o src/generated_contracts.rs\n");
406    content.push_str(
407        "// Include:   #[macro_use] #[allow(unused_macros)] mod generated_contracts;\n\n",
408    );
409
410    let mut total_pre = 0;
411    let mut total_post = 0;
412    let mut total_inv = 0;
413
414    for c in contracts {
415        content.push_str(&c.rust_assertions);
416        total_pre += c.precondition_count;
417        total_post += c.postcondition_count;
418        total_inv += c.invariant_count;
419    }
420
421    content.push_str(&format!(
422        "// Total: {} preconditions, {} postconditions, {} invariants from {} contracts\n",
423        total_pre,
424        total_post,
425        total_inv,
426        contracts.len()
427    ));
428
429    std::fs::write(output, content)
430}
431
// Unit tests for this module live in the sibling file `codegen_tests.rs`.
#[path = "codegen_tests.rs"]
#[cfg(test)]
mod tests;