provable_contracts/scaffold/
mod.rs1use crate::schema::Contract;
7
8pub fn generate_trait(contract: &Contract) -> String {
13 let mut out = String::new();
14 let desc = &contract.metadata.description;
15
16 out.push_str(&format!(
18 "/// Contract: {} v{}\n",
19 desc, contract.metadata.version
20 ));
21 for r in &contract.metadata.references {
22 out.push_str(&format!("/// Paper: {r}\n"));
23 }
24 out.push_str("pub trait KernelContract {\n");
25
26 for (name, eq) in &contract.equations {
28 out.push_str(&format!(" /// {}\n", eq.formula));
29 if let Some(ref domain) = eq.domain {
30 out.push_str(&format!(" /// Domain: {domain}\n"));
31 }
32 if let Some(ref codomain) = eq.codomain {
33 out.push_str(&format!(" /// Codomain: {codomain}\n"));
34 }
35 for inv in &eq.invariants {
36 out.push_str(&format!(" /// INVARIANT: {inv}\n"));
37 }
38 for ob in &contract.proof_obligations {
40 out.push_str(&format!(
41 " /// {} ({}): {}\n",
42 ob.obligation_type.to_string().to_uppercase(),
43 ob.property,
44 ob.formal.as_deref().unwrap_or("")
45 ));
46 }
47 out.push_str(&format!(
48 " fn {name}(&self, input: &[f32], output: &mut [f32]);\n"
49 ));
50 }
51
52 out.push_str("}\n");
53 out
54}
55
56pub fn generate_standalone_trait(contract: &Contract, stem: &str) -> String {
70 let trait_name = stem_to_trait_name(stem);
71 let mut out = String::new();
72
73 out.push_str(&format!(
75 "//! Auto-generated contract trait for `{stem}`.\n"
76 ));
77 out.push_str(&format!(
78 "//! Generated by: `pv scaffold --trait contracts/{stem}.yaml`\n"
79 ));
80 out.push_str("//! DO NOT EDIT — regenerate from YAML source.\n\n");
81 out.push_str("#![allow(clippy::doc_markdown)]\n\n");
82
83 out.push_str(&format!(
85 "/// Contract trait for `{stem}` v{}.\n",
86 contract.metadata.version
87 ));
88 out.push_str(&format!("///\n/// {}\n", contract.metadata.description));
89 for r in &contract.metadata.references {
90 out.push_str(&format!("/// Reference: {r}\n"));
91 }
92 out.push_str("///\n");
93 out.push_str(&format!(
94 "/// Implementors must provide all {} equation(s).\n",
95 contract.equations.len()
96 ));
97 out.push_str("/// Missing method = compile error. Wrong signature = compile error.\n");
98
99 out.push_str(&format!("pub trait {trait_name} {{\n"));
100
101 let eq_count = contract.equations.len();
103 for (i, (name, eq)) in contract.equations.iter().enumerate() {
104 out.push_str(&format!(" /// `{name}`: {}\n", eq.formula));
105 if let Some(ref domain) = eq.domain {
106 out.push_str(&format!(" /// Domain: {domain}\n"));
107 }
108 if let Some(ref codomain) = eq.codomain {
109 out.push_str(&format!(" /// Codomain: {codomain}\n"));
110 }
111 for inv in &eq.invariants {
112 out.push_str(&format!(" /// Invariant: {inv}\n"));
113 }
114 let method_name = name.replace('-', "_").to_lowercase();
116 let params = domain_to_params(eq.domain.as_deref());
117 out.push_str(&format!(" fn {method_name}({params}) -> Vec<f32>;\n"));
118 if i + 1 < eq_count {
120 out.push('\n');
121 }
122 }
123
124 out.push_str("}\n");
125 out
126}
127
/// Derives a method parameter list from an equation's `domain` string.
///
/// The domain is split on `,`. A segment contributes a parameter only if it
/// has the shape `<var> ∈ <set>` or `<var> in <set>`; anything else (for
/// example `m_0 = 0`) is ignored. The variable is reduced to its ASCII
/// alphanumeric and `_` characters and lowercased. A segment is dropped when:
///
/// * the raw variable is empty or contains `(`, `>` or `<`;
/// * the cleaned name is empty, longer than 20 characters, or starts with a
///   digit;
/// * the cleaned name starts with `num`, `beta` or `eps` (the
///   `adamw_filters_scalars` test suggests these are meant to exclude
///   hyperparameters).
///
/// A parameter is typed `f32` when its segment mentions `ℝ` without `^` or
/// `×`, and `&[f32]` otherwise. The result always begins with `&self`; when
/// `domain` is `None` or yields no parameters it is `"&self, input: &[f32]"`.
fn domain_to_params(domain: Option<&str>) -> String {
    const FALLBACK: &str = "&self, input: &[f32]";
    const SKIPPED_PREFIXES: [&str; 3] = ["num", "beta", "eps"];

    let Some(text) = domain else {
        return FALLBACK.to_string();
    };

    let declarations: Vec<String> = text
        .split(',')
        .map(str::trim)
        .filter_map(|piece| {
            // `∈` takes precedence over the spelled-out ` in ` form.
            let (lhs, _) = piece
                .split_once('∈')
                .or_else(|| piece.split_once(" in "))?;
            let lhs = lhs.trim();

            if lhs.is_empty() || lhs.chars().any(|ch| matches!(ch, '(' | '>' | '<')) {
                return None;
            }

            let ident = lhs
                .chars()
                .filter(|ch| *ch == '_' || ch.is_ascii_alphanumeric())
                .collect::<String>()
                .to_lowercase();

            let rejected = ident.is_empty()
                || ident.len() > 20
                || SKIPPED_PREFIXES
                    .iter()
                    .any(|prefix| ident.starts_with(prefix))
                || ident.starts_with(|ch: char| ch.is_ascii_digit());
            if rejected {
                return None;
            }

            let has_shape = piece.contains('^') || piece.contains('×');
            let ty = if piece.contains('ℝ') && !has_shape {
                "f32"
            } else {
                "&[f32]"
            };
            Some(format!("{ident}: {ty}"))
        })
        .collect();

    if declarations.is_empty() {
        return FALLBACK.to_string();
    }
    format!("&self, {}", declarations.join(", "))
}
187
#[cfg(test)]
mod domain_tests {
    use super::domain_to_params;

    /// One vector-valued variable yields one slice parameter.
    #[test]
    fn single_vector() {
        let params = domain_to_params(Some("x ∈ ℝ^n"));
        assert_eq!(params, "&self, x: &[f32]");
    }

    /// Matrix-valued Q/K/V are lowercased and typed as slices.
    #[test]
    fn qkv_attention() {
        let params = domain_to_params(Some("Q ∈ ℝ^{n×d_k}, K ∈ ℝ^{m×d_k}, V ∈ ℝ^{m×d_v}"));
        assert_eq!(params, "&self, q: &[f32], k: &[f32], v: &[f32]");
    }

    /// Two matrix operands keep their declaration order.
    #[test]
    fn matmul_ab() {
        let params = domain_to_params(Some("A ∈ ℝ^{m×p}, B ∈ ℝ^{p×n}"));
        assert_eq!(params, "&self, a: &[f32], b: &[f32]");
    }

    /// A definition segment (`θ_k = …`) is ignored; `m ∈ ℕ` is kept as a slice.
    #[test]
    fn rope_with_position() {
        let params = domain_to_params(Some("x ∈ ℝ^d, m ∈ ℕ, θ_k = 10000^(-2k/d)"));
        assert_eq!(params, "&self, x: &[f32], m: &[f32]");
    }

    /// The ASCII ` in ` form is accepted; `m_0 = 0` and `beta1` are dropped.
    #[test]
    fn adamw_filters_scalars() {
        let params = domain_to_params(Some("g_t in R^d, m_0 = 0, beta1 in (0, 1)"));
        assert_eq!(params, "&self, g_t: &[f32]");
    }

    /// A missing domain falls back to the generic `input` parameter.
    #[test]
    fn none_domain() {
        let params = domain_to_params(None);
        assert_eq!(params, "&self, input: &[f32]");
    }

    /// An empty domain string falls back to the generic `input` parameter.
    #[test]
    fn empty_domain() {
        let params = domain_to_params(Some(""));
        assert_eq!(params, "&self, input: &[f32]");
    }
}
231
/// Converts a hyphen-separated file stem into a trait name.
///
/// The stem is split on `-`; the first character of each piece is uppercased
/// and the remainder is appended unchanged. Empty pieces contribute nothing,
/// so an empty stem yields an empty string.
fn stem_to_trait_name(stem: &str) -> String {
    let mut trait_name = String::new();
    for word in stem.split('-') {
        let mut rest = word.chars();
        if let Some(initial) = rest.next() {
            // `to_uppercase` may expand to more than one character.
            trait_name.extend(initial.to_uppercase());
            trait_name.push_str(rest.as_str());
        }
    }
    trait_name
}
249
250pub fn generate_contract_tests(contract: &Contract) -> String {
254 let mut out = String::new();
255
256 out.push_str("#[cfg(test)]\nmod contract_tests {\n");
257 out.push_str(" use super::*;\n\n");
258
259 for test in &contract.falsification_tests {
260 out.push_str(&format!(" /// {}: {}\n", test.id, test.rule));
261 out.push_str(&format!(" /// Prediction: {}\n", test.prediction));
262 out.push_str(&format!(" /// If fails: {}\n", test.if_fails));
263 let fn_name = test.id.to_lowercase().replace('-', "_");
264 out.push_str(&format!(" #[test]\n fn {fn_name}() {{\n"));
265 out.push_str(&format!(
266 " todo!(\"Implementation not yet written — \
267 {} MUST fail\")\n",
268 test.id
269 ));
270 out.push_str(" }\n\n");
271 }
272
273 out.push_str("}\n");
274 out
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::parse_contract_str;

    /// Fixture with a single `softmax` equation, one proof obligation and two
    /// falsification tests.
    ///
    /// The YAML nesting matters: `formula`, `domain`, `codomain` and
    /// `invariants` must sit under `softmax`, and each list item's fields must
    /// be indented under its `-`, otherwise the document does not describe the
    /// structure the generators read.
    fn sample_contract() -> Contract {
        parse_contract_str(
            r#"
metadata:
  version: "1.0.0"
  description: "Test kernel"
  references:
    - "Paper (2024)"
equations:
  softmax:
    formula: "σ(x) = exp(x-max) / Σexp(x-max)"
    domain: "ℝ^n"
    codomain: "(0,1)^n"
    invariants:
      - "sum(output) = 1.0"
proof_obligations:
  - type: invariant
    property: "normalization"
    formal: "|sum(σ(x)) - 1.0| < ε"
falsification_tests:
  - id: FALSIFY-SM-001
    rule: "normalization"
    prediction: "sum(output) ≈ 1.0"
    if_fails: "missing max subtraction"
  - id: FALSIFY-SM-002
    rule: "positivity"
    prediction: "output > 0"
    if_fails: "exp underflow"
"#,
        )
        .expect("sample contract fixture should parse")
    }

    /// The trait lists each equation as a method with its invariants.
    #[test]
    fn generate_trait_includes_equations() {
        let contract = sample_contract();
        let code = generate_trait(&contract);
        assert!(code.contains("pub trait KernelContract"));
        assert!(code.contains("fn softmax"));
        assert!(code.contains("INVARIANT: sum(output) = 1.0"));
    }

    /// Every falsification test id becomes a `todo!` stub.
    #[test]
    fn generate_tests_creates_stubs() {
        let contract = sample_contract();
        let code = generate_contract_tests(&contract);
        assert!(code.contains("fn falsify_sm_001()"));
        assert!(code.contains("fn falsify_sm_002()"));
        assert!(code.contains("todo!"));
    }

    /// Stub doc comments carry the prediction and failure hint.
    #[test]
    fn generate_tests_includes_predictions() {
        let contract = sample_contract();
        let code = generate_contract_tests(&contract);
        assert!(code.contains("sum(output) ≈ 1.0"));
        assert!(code.contains("missing max subtraction"));
    }

    /// References are emitted as `Paper:` header lines.
    #[test]
    fn generate_trait_includes_paper_refs() {
        let contract = sample_contract();
        let code = generate_trait(&contract);
        assert!(code.contains("Paper: Paper (2024)"));
    }

    /// Domain and codomain appear in the method docs when present.
    #[test]
    fn generate_trait_includes_domain_codomain() {
        let contract = sample_contract();
        let code = generate_trait(&contract);
        assert!(code.contains("Domain:"));
        assert!(code.contains("Codomain:"));
    }

    /// Proof obligations are rendered with an uppercased type and property.
    #[test]
    fn generate_trait_includes_proof_obligation() {
        let contract = sample_contract();
        let code = generate_trait(&contract);
        assert!(code.contains("INVARIANT"));
        assert!(code.contains("normalization"));
    }

    /// Hyphenated stems are camel-cased; the empty stem stays empty.
    #[test]
    fn stem_to_trait_name_basic() {
        assert_eq!(stem_to_trait_name("softmax-kernel-v1"), "SoftmaxKernelV1");
        assert_eq!(stem_to_trait_name("gelu-kernel-v1"), "GeluKernelV1");
        assert_eq!(stem_to_trait_name("a"), "A");
        assert_eq!(stem_to_trait_name(""), "");
    }

    /// The standalone file has the generated-file banner and lint attribute.
    #[test]
    fn generate_standalone_trait_header() {
        let contract = sample_contract();
        let code = generate_standalone_trait(&contract, "softmax-kernel-v1");
        assert!(code.contains("pub trait SoftmaxKernelV1"));
        assert!(code.contains("Auto-generated contract trait"));
        assert!(code.contains("DO NOT EDIT"));
        assert!(code.contains("#![allow(clippy::doc_markdown)]"));
    }

    /// Standalone methods return `Vec<f32>`.
    #[test]
    fn generate_standalone_trait_methods() {
        let contract = sample_contract();
        let code = generate_standalone_trait(&contract, "softmax-kernel-v1");
        assert!(code.contains("fn softmax("));
        assert!(code.contains("-> Vec<f32>"));
    }

    /// Standalone method docs use the `Invariant:` label.
    #[test]
    fn generate_standalone_trait_invariants() {
        let contract = sample_contract();
        let code = generate_standalone_trait(&contract, "softmax-kernel-v1");
        assert!(code.contains("Invariant: sum(output) = 1.0"));
    }

    /// Standalone trait docs use the `Reference:` label.
    #[test]
    fn generate_standalone_trait_references() {
        let contract = sample_contract();
        let code = generate_standalone_trait(&contract, "softmax-kernel-v1");
        assert!(code.contains("Reference: Paper (2024)"));
    }

    /// The implementor note reports the equation count.
    #[test]
    fn generate_standalone_trait_implementor_note() {
        let contract = sample_contract();
        let code = generate_standalone_trait(&contract, "test-v1");
        assert!(code.contains("Implementors must provide all 1 equation(s)"));
        assert!(code.contains("Missing method = compile error"));
    }

    /// The generated test module has its scaffolding and one stub per id.
    #[test]
    fn generate_contract_tests_all_ids() {
        let contract = sample_contract();
        let code = generate_contract_tests(&contract);
        assert!(code.contains("#[cfg(test)]"));
        assert!(code.contains("mod contract_tests"));
        assert!(code.contains("use super::*;"));
        assert!(code.contains("fn falsify_sm_001()"));
        assert!(code.contains("fn falsify_sm_002()"));
    }

    /// Fixture with two equations (`alpha`, `beta`), two references, one
    /// `bound` proof obligation and one falsification test.
    fn multi_equation_contract() -> Contract {
        parse_contract_str(
            r#"
metadata:
  version: "2.0.0"
  description: "Multi-equation kernel"
  references:
    - "Ref A"
    - "Ref B"
equations:
  alpha:
    formula: "alpha(x) = x^2"
    domain: "x ∈ ℝ^n"
    codomain: "ℝ^n"
    invariants:
      - "output >= 0"
  beta:
    formula: "beta(x) = 2x"
    domain: "x ∈ ℝ^n"
    invariants:
      - "output proportional to input"
proof_obligations:
  - type: bound
    property: "non-negativity"
    formal: "∀x: alpha(x) ≥ 0"
falsification_tests:
  - id: FALSIFY-MQ-001
    rule: "non-neg"
    prediction: "alpha >= 0"
    if_fails: "squared value is negative"
"#,
        )
        .expect("multi-equation contract fixture should parse")
    }

    /// Each equation gets its own method; obligation types are uppercased.
    #[test]
    fn generate_trait_multiple_equations() {
        let contract = multi_equation_contract();
        let code = generate_trait(&contract);
        assert!(code.contains("fn alpha("));
        assert!(code.contains("fn beta("));
        assert!(code.contains("BOUND"));
    }

    /// The standalone trait lists both methods and reports the count.
    #[test]
    fn generate_standalone_multiple_equations() {
        let contract = multi_equation_contract();
        let code = generate_standalone_trait(&contract, "multi-eq-v1");
        assert!(code.contains("pub trait MultiEqV1"));
        assert!(code.contains("fn alpha("));
        assert!(code.contains("fn beta("));
        assert!(code.contains("2 equation(s)"));
    }

    /// The contract version appears in the trait header.
    #[test]
    fn generate_trait_version_in_header() {
        let contract = sample_contract();
        let code = generate_trait(&contract);
        assert!(code.contains("v1.0.0"));
    }
}