1use crate::schema::Contract;
9use std::path::Path;
10
/// Output of generating assertions for a single contract.
///
/// Produced by `generate_from_contract`; the counts record how many
/// `debug_assert!` lines / Lean stubs were actually emitted (conditions that were
/// skipped by the generator are not counted).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedContract {
    /// Contract name (the YAML file stem when produced by `generate_all`).
    pub name: String,
    /// Generated Rust source: `macro_rules!` definitions for the contract's equations.
    pub rust_assertions: String,
    /// Lean stub text: `--` comment lines, one group per equation with a Lean theorem.
    pub lean_stubs: String,
    /// Number of precondition `debug_assert!` lines emitted.
    pub precondition_count: usize,
    /// Number of postcondition `debug_assert!` lines emitted.
    pub postcondition_count: usize,
    /// Number of equations that referenced a Lean theorem.
    pub lean_theorem_count: usize,
    /// Number of invariant `debug_assert!` lines emitted.
    pub invariant_count: usize,
}
29
30pub fn generate_from_contract(name: &str, contract: &Contract) -> GeneratedContract {
39 let mut rust = String::new();
40 let mut lean = String::new();
41 let mut pre_count = 0;
42 let mut post_count = 0;
43 let mut invariant_count = 0;
44 let mut lean_count = 0;
45
46 rust.push_str(&format!(
47 "// Auto-generated from contracts/{name}.yaml — DO NOT EDIT\n"
48 ));
49 rust.push_str(&format!("// Contract: {name}\n\n"));
50
51 for (eq_name, equation) in &contract.equations {
52 let macro_name = eq_name.replace('-', "_").to_lowercase();
53 pre_count +=
54 emit_precondition_macro(&mut rust, eq_name, ¯o_name, &equation.preconditions);
55 post_count +=
56 emit_postcondition_macro(&mut rust, eq_name, ¯o_name, &equation.postconditions);
57 invariant_count +=
58 emit_invariant_macro(&mut rust, eq_name, ¯o_name, &equation.invariants);
59 emit_combined_macro(
60 &mut rust,
61 eq_name,
62 ¯o_name,
63 &equation.preconditions,
64 &equation.postconditions,
65 );
66
67 if let Some(ref theorem) = equation.lean_theorem {
69 lean.push_str(&format!("-- Equation: {eq_name}\n"));
70 lean.push_str(&format!("-- Lean theorem: {theorem}\n"));
71 lean.push_str(&format!(
72 "-- Formula: {}\n\n",
73 equation.formula.lines().next().unwrap_or("")
74 ));
75 lean_count += 1;
76 }
77 }
78
79 GeneratedContract {
80 name: name.to_string(),
81 rust_assertions: rust,
82 lean_stubs: lean,
83 precondition_count: pre_count,
84 postcondition_count: post_count,
85 lean_theorem_count: lean_count,
86 invariant_count,
87 }
88}
89
90fn emit_precondition_macro(
92 rust: &mut String,
93 eq_name: &str,
94 macro_name: &str,
95 pres: &[String],
96) -> usize {
97 if pres.is_empty() {
98 return 0;
99 }
100 let uses_domain = pres.iter().any(|p| {
101 p.contains("==")
102 || p.contains("eps")
103 || p.contains("weight")
104 || p.contains("freqs")
105 || p.contains("scale")
106 || p.contains('.') && !p.contains("is_empty")
107 });
108 let mut count = 0;
109 rust.push_str(&format!("/// Preconditions for equation `{eq_name}`.\n"));
110 if uses_domain {
111 let pv = detect_primary_var(pres);
112 let safe_pv = format!("_pv_{pv}");
114 rust.push_str(&format!(
115 "/// Domain-specific. Call: `contract_pre_{macro_name}!(slice_expr)`\n"
116 ));
117 rust.push_str(&format!("macro_rules! contract_pre_{macro_name} {{\n"));
118 rust.push_str(" () => {{}};\n");
120 rust.push_str(" ($input:expr) => {{\n");
121 rust.push_str(&format!(" let {safe_pv} = &$input;\n"));
122 for pre in pres {
123 if has_unbound_vars(pre, &pv) {
124 continue;
125 }
126 let mapped = pre.replace(&pv, &safe_pv);
127 let esc = pre.replace('"', "\\\"");
128 rust.push_str(&format!(" debug_assert!({mapped},\n \"Contract {eq_name}: precondition violated — {esc}\");\n"));
129 count += 1;
130 }
131 rust.push_str(" }};\n}\n\n");
132 } else {
133 rust.push_str(&format!(
134 "/// Call at function entry: `contract_pre_{macro_name}!(input_expr)`\n"
135 ));
136 rust.push_str(&format!("macro_rules! contract_pre_{macro_name} {{\n"));
137 rust.push_str(" () => {{}};\n");
138 rust.push_str(" ($input:expr) => {{\n let _contract_input = &$input;\n");
139 for pre in pres {
140 let mut assertion = pre
142 .replace("input", "_contract_input")
143 .replace("x.", "_contract_input.")
144 .replace("x)", "_contract_input)");
145 if has_unbound_vars(&assertion, "_contract_input") {
148 let stripped = pre.trim_start_matches('!');
149 if let Some(dot) = stripped.find('.') {
150 let var = &stripped[..dot];
151 let method = &stripped[dot + 1..];
152 let safe_method = method.starts_with("len()")
154 || method.starts_with("iter()")
155 || method.starts_with("is_finite()");
156 if safe_method
157 && !var.is_empty()
158 && var.chars().all(|c| c.is_alphanumeric() || c == '_')
159 {
160 let mapped = pre.replace(var, "_contract_input");
161 if !has_unbound_vars(&mapped, "_contract_input") {
162 assertion = mapped;
163 }
164 }
165 }
166 }
167 if has_unbound_vars(&assertion, "_contract_input") {
169 continue;
170 }
171 let esc = pre.replace('"', "\\\"");
172 rust.push_str(&format!(" debug_assert!({assertion},\n \"Contract {eq_name}: precondition violated — {esc}\");\n"));
173 count += 1;
174 }
175 rust.push_str(" }};\n}\n\n");
176 }
177 count
178}
179
180fn emit_postcondition_macro(
182 rust: &mut String,
183 eq_name: &str,
184 macro_name: &str,
185 posts: &[String],
186) -> usize {
187 if posts.is_empty() {
188 return 0;
189 }
190 let mut count = 0;
191 rust.push_str(&format!("/// Postconditions for equation `{eq_name}`.\n"));
192 rust.push_str(&format!(
193 "/// Call before return: `contract_post_{macro_name}!(result_expr)`\n"
194 ));
195 rust.push_str(&format!("macro_rules! contract_post_{macro_name} {{\n"));
196 rust.push_str(" ($result:expr) => {{\n let _contract_result = &$result;\n");
197 for post in posts {
198 let fixed = if post.contains("result.") || post.contains("result)") {
201 post.replace("result", "_contract_result")
202 } else {
203 post.replace("result", "*_contract_result")
205 };
206 if has_unbound_vars(&fixed, "_contract_result") {
208 continue;
209 }
210 let esc = post.replace('"', "\\\"");
211 rust.push_str(&format!(" debug_assert!({fixed}, \"Contract {eq_name}: postcondition violated — {esc}\");\n"));
212 count += 1;
213 }
214 rust.push_str(" }};\n}\n\n");
215 count
216}
217
218fn emit_invariant_macro(
221 rust: &mut String,
222 eq_name: &str,
223 macro_name: &str,
224 invariants: &[String],
225) -> usize {
226 if invariants.is_empty() {
227 return 0;
228 }
229 let mut count = 0;
230 rust.push_str(&format!("/// Invariants for equation `{eq_name}`.\n"));
231 rust.push_str(&format!(
232 "/// Check after computation: `contract_inv_{macro_name}!(result_expr)`\n"
233 ));
234 rust.push_str(&format!("macro_rules! contract_inv_{macro_name} {{\n"));
235 rust.push_str(" () => {{}};\n");
236 rust.push_str(" ($result:expr) => {{\n let _contract_result = &$result;\n");
237 for inv in invariants {
238 let fixed = if inv.contains("result.") || inv.contains("result)") {
240 inv.replace("result", "_contract_result")
241 } else if inv.contains(">=")
242 || inv.contains("<=")
243 || inv.contains("==")
244 || inv.contains("> ")
245 || inv.contains("< ")
246 {
247 inv.replace("result", "*_contract_result")
248 } else {
249 continue; };
251 if has_unbound_vars(&fixed, "_contract_result") {
253 continue;
254 }
255 let esc = inv.replace('"', "\\\"");
256 rust.push_str(&format!(" debug_assert!({fixed}, \"Contract {eq_name}: invariant violated \u{2014} {esc}\");\n"));
257 count += 1;
258 }
259 rust.push_str(" }};\n}\n\n");
260 count
261}
262
/// Emits `contract_<macro_name>!(input, body)`, a wrapper that runs the precondition
/// macro on `input`, evaluates `body` once, runs the postcondition macro on the
/// value, and yields that value.
///
/// Only emitted when both `pres` and `posts` are non-empty, because the wrapper calls
/// both `contract_pre_*` and `contract_post_*`.
fn emit_combined_macro(
    rust: &mut String,
    eq_name: &str,
    macro_name: &str,
    pres: &[String],
    posts: &[String],
) {
    let both_present = !pres.is_empty() && !posts.is_empty();
    if !both_present {
        return;
    }
    let doc_line = format!("/// Combined pre+post contract for equation `{eq_name}`.\n");
    let opener = format!("macro_rules! contract_{macro_name} {{\n");
    let pre_call = format!(" contract_pre_{macro_name}!($input);\n");
    let post_call = format!(" contract_post_{macro_name}!(_contract_result);\n");
    let pieces = [
        doc_line.as_str(),
        opener.as_str(),
        " ($input:expr, $body:expr) => {{\n",
        pre_call.as_str(),
        " let _contract_result = $body;\n",
        post_call.as_str(),
        " _contract_result\n",
        " }};\n}\n\n",
    ];
    for piece in pieces {
        rust.push_str(piece);
    }
}
287
/// Picks the variable name that domain preconditions are written against.
///
/// Returns the text before the first `.` of the first precondition whose prefix is
/// a non-empty run of alphanumerics/underscores other than `result`. For example,
/// `logits.len() > 0` yields `logits`. Falls back to `"x"` when no precondition
/// qualifies.
fn detect_primary_var(preconditions: &[String]) -> String {
    let is_plain_ident =
        |s: &str| !s.is_empty() && s.chars().all(|c| c.is_alphanumeric() || c == '_');
    preconditions
        .iter()
        .filter_map(|pre| pre.split_once('.').map(|(head, _)| head))
        .find(|head| is_plain_ident(head) && *head != "result")
        .unwrap_or("x")
        .to_string()
}
306
/// Heuristically reports whether `expr` references an identifier the generated macro
/// cannot bind.
///
/// `expr` is split on the characters `().&|!<>=+- */%,;{}[]`, and each piece is
/// trimmed. A piece is accepted when it is:
/// * empty or starts with an ASCII digit;
/// * `primary_var`, `_contract_input`, `true`/`false`, or one of the primitive type
///   names `f32`, `f64`, `usize`, `i32`, `i64`;
/// * one of the closure names `v`/`id`;
/// * prefixed with `is_`;
/// * one of `iter`, `all`, `any`, `len`, `abs`, `sum`.
///
/// Any other piece made only of alphanumerics/`_` and at most 20 bytes long counts
/// as unbound. Pieces containing other characters (e.g. `f32::EPSILON`) or longer
/// than 20 bytes are not flagged.
/// NOTE(review): the 20-byte cut-off lets long unknown identifiers through —
/// confirm this is intended.
fn has_unbound_vars(expr: &str, primary_var: &str) -> bool {
    const SEPARATORS: &str = "().&|!<>=+- */%,;{}[]";
    let is_known = |tok: &str| {
        tok == primary_var
            || tok.starts_with("is_")
            || matches!(
                tok,
                "_contract_input"
                    | "true"
                    | "false"
                    | "f32"
                    | "f64"
                    | "usize"
                    | "i32"
                    | "i64"
                    | "v"
                    | "id"
                    | "iter"
                    | "all"
                    | "any"
                    | "len"
                    | "abs"
                    | "sum"
            )
    };
    expr.split(|c: char| SEPARATORS.contains(c))
        .map(str::trim)
        .filter(|tok| !tok.is_empty() && !tok.starts_with(|c: char| c.is_ascii_digit()))
        .filter(|tok| !is_known(*tok))
        .any(|tok| tok.len() <= 20 && tok.chars().all(|c| c.is_alphanumeric() || c == '_'))
}
350
351pub fn generate_all(contract_dir: &Path) -> Vec<GeneratedContract> {
353 let mut yaml_paths = Vec::new();
354 collect_yaml_files(contract_dir, &mut yaml_paths);
355
356 let mut results = Vec::new();
357 for path in &yaml_paths {
358 let stem = path
359 .file_stem()
360 .and_then(|s| s.to_str())
361 .unwrap_or("unknown")
362 .to_string();
363
364 if let Ok(contract) = crate::schema::parse_contract(path) {
365 let generated = generate_from_contract(&stem, &contract);
366 if generated.precondition_count > 0
367 || generated.postcondition_count > 0
368 || generated.lean_theorem_count > 0
369 {
370 results.push(generated);
371 }
372 }
373 }
374
375 results.sort_by(|a, b| a.name.cmp(&b.name));
376 results
377}
378
/// Recursively appends every `*.yaml` file under `dir` to `out`.
///
/// Files named `binding.yaml` are excluded, and directories named `kaizen`,
/// `legacy` or `pipelines` are not descended into. Directories that cannot be read,
/// and entries that fail to load, are silently ignored.
/// NOTE(review): `Path::is_dir` follows symlinks, so a symlink to an ancestor
/// directory would recurse until path resolution fails.
fn collect_yaml_files(dir: &Path, out: &mut Vec<std::path::PathBuf>) {
    const SKIPPED_DIRS: [&str; 3] = ["kaizen", "legacy", "pipelines"];
    let entries = match std::fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return,
    };
    for path in entries.filter_map(Result::ok).map(|entry| entry.path()) {
        if path.is_dir() {
            let dirname = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
            if !SKIPPED_DIRS.contains(&dirname) {
                collect_yaml_files(&path, out);
            }
            continue;
        }
        let is_yaml = path.extension().and_then(|e| e.to_str()) == Some("yaml");
        let is_binding = path.file_name().and_then(|n| n.to_str()) == Some("binding.yaml");
        if is_yaml && !is_binding {
            out.push(path);
        }
    }
}
399
400pub fn write_rust_module(contracts: &[GeneratedContract], output: &Path) -> std::io::Result<()> {
402 let mut content = String::new();
403 content.push_str("// Auto-generated contract assertions from YAML — DO NOT EDIT.\n");
404 content.push_str("// Zero cost in release builds (debug_assert!).\n");
405 content.push_str("// Regenerate: pv codegen contracts/ -o src/generated_contracts.rs\n");
406 content.push_str(
407 "// Include: #[macro_use] #[allow(unused_macros)] mod generated_contracts;\n\n",
408 );
409
410 let mut total_pre = 0;
411 let mut total_post = 0;
412 let mut total_inv = 0;
413
414 for c in contracts {
415 content.push_str(&c.rust_assertions);
416 total_pre += c.precondition_count;
417 total_post += c.postcondition_count;
418 total_inv += c.invariant_count;
419 }
420
421 content.push_str(&format!(
422 "// Total: {} preconditions, {} postconditions, {} invariants from {} contracts\n",
423 total_pre,
424 total_post,
425 total_inv,
426 contracts.len()
427 ));
428
429 std::fs::write(output, content)
430}
431
// Unit tests for this module are kept in the separate file `codegen_tests.rs`,
// compiled only for test builds.
#[cfg(test)]
#[path = "codegen_tests.rs"]
mod tests;