Skip to main content

provable_contracts/scaffold/
mod.rs

1//! Scaffold generator — Phase 3 of the pipeline.
2//!
3//! Generates Rust trait definitions and failing test stubs
4//! from parsed YAML contracts.
5
6use crate::schema::Contract;
7
8/// Generate a Rust trait definition from a contract.
9///
10/// Each equation becomes a method. Each proof obligation
11/// becomes a doc-comment with INVARIANT/REQUIRES prefix.
12pub fn generate_trait(contract: &Contract) -> String {
13    let mut out = String::new();
14    let desc = &contract.metadata.description;
15
16    // Header
17    out.push_str(&format!(
18        "/// Contract: {} v{}\n",
19        desc, contract.metadata.version
20    ));
21    for r in &contract.metadata.references {
22        out.push_str(&format!("/// Paper: {r}\n"));
23    }
24    out.push_str("pub trait KernelContract {\n");
25
26    // One method per equation
27    for (name, eq) in &contract.equations {
28        out.push_str(&format!("    /// {}\n", eq.formula));
29        if let Some(ref domain) = eq.domain {
30            out.push_str(&format!("    /// Domain: {domain}\n"));
31        }
32        if let Some(ref codomain) = eq.codomain {
33            out.push_str(&format!("    /// Codomain: {codomain}\n"));
34        }
35        for inv in &eq.invariants {
36            out.push_str(&format!("    /// INVARIANT: {inv}\n"));
37        }
38        // Add proof obligations for this equation
39        for ob in &contract.proof_obligations {
40            out.push_str(&format!(
41                "    /// {} ({}): {}\n",
42                ob.obligation_type.to_string().to_uppercase(),
43                ob.property,
44                ob.formal.as_deref().unwrap_or("")
45            ));
46        }
47        out.push_str(&format!(
48            "    fn {name}(&self, input: &[f32], output: &mut [f32]);\n"
49        ));
50    }
51
52    out.push_str("}\n");
53    out
54}
55
56/// Generate a standalone, named contract trait from a YAML contract.
57///
58/// Unlike `generate_trait` (which produces a generic `KernelContract`),
59/// this generates a **named** trait specific to the contract (e.g.,
60/// `SoftmaxKernelV1`) with proper doc comments and equations as methods.
61///
62/// Consumer crates `impl` this trait. Missing method = compile error.
63/// Wrong signature = compile error. This is Layer 2 enforcement (§23).
64///
65/// # Arguments
66///
67/// * `contract` - Parsed YAML contract
68/// * `stem` - Contract stem (e.g., "softmax-kernel-v1")
69pub fn generate_standalone_trait(contract: &Contract, stem: &str) -> String {
70    let trait_name = stem_to_trait_name(stem);
71    let mut out = String::new();
72
73    // Module header
74    out.push_str(&format!(
75        "//! Auto-generated contract trait for `{stem}`.\n"
76    ));
77    out.push_str(&format!(
78        "//! Generated by: `pv scaffold --trait contracts/{stem}.yaml`\n"
79    ));
80    out.push_str("//! DO NOT EDIT — regenerate from YAML source.\n\n");
81    out.push_str("#![allow(clippy::doc_markdown)]\n\n");
82
83    // Trait doc
84    out.push_str(&format!(
85        "/// Contract trait for `{stem}` v{}.\n",
86        contract.metadata.version
87    ));
88    out.push_str(&format!("///\n/// {}\n", contract.metadata.description));
89    for r in &contract.metadata.references {
90        out.push_str(&format!("/// Reference: {r}\n"));
91    }
92    out.push_str("///\n");
93    out.push_str(&format!(
94        "/// Implementors must provide all {} equation(s).\n",
95        contract.equations.len()
96    ));
97    out.push_str("/// Missing method = compile error. Wrong signature = compile error.\n");
98
99    out.push_str(&format!("pub trait {trait_name} {{\n"));
100
101    // One method per equation
102    let eq_count = contract.equations.len();
103    for (i, (name, eq)) in contract.equations.iter().enumerate() {
104        out.push_str(&format!("    /// `{name}`: {}\n", eq.formula));
105        if let Some(ref domain) = eq.domain {
106            out.push_str(&format!("    /// Domain: {domain}\n"));
107        }
108        if let Some(ref codomain) = eq.codomain {
109            out.push_str(&format!("    /// Codomain: {codomain}\n"));
110        }
111        for inv in &eq.invariants {
112            out.push_str(&format!("    /// Invariant: {inv}\n"));
113        }
114        // Use equation name as method name, sanitized
115        let method_name = name.replace('-', "_").to_lowercase();
116        let params = domain_to_params(eq.domain.as_deref());
117        out.push_str(&format!("    fn {method_name}({params}) -> Vec<f32>;\n"));
118        // Blank line between methods, but not after the last one
119        if i + 1 < eq_count {
120            out.push('\n');
121        }
122    }
123
124    out.push_str("}\n");
125    out
126}
127
/// Parse a YAML domain string to generate Rust method parameters.
///
/// The domain is split on `,`. A segment of the form `name ∈ set` or
/// `name in set` contributes one parameter after `&self`; segments without
/// either separator are ignored. The name is lowercased and reduced to ASCII
/// alphanumerics and `_`. Names that end up empty, are longer than 20
/// characters, start with a digit, or start with `num`, `beta` or `eps` are
/// dropped.
///
/// A parameter is typed `f32` only when its segment mentions `ℝ` with neither
/// `^` nor `×`; every other accepted segment is typed `&[f32]`.
///
/// Because names are lowercased, two distinct variables can map to the same
/// identifier (`W` and `w`). Rust rejects duplicate parameter names, so every
/// later duplicate gets a numeric suffix (`w_2`, `w_3`, ...).
///
/// If no segment yields a parameter, or `domain` is `None`, the result is
/// `"&self, input: &[f32]"`.
///
/// Examples:
/// - `"x ∈ ℝ^n"` → `"&self, x: &[f32]"`
/// - `"Q ∈ ℝ^{n×d_k}, K ∈ ℝ^{m×d_k}, V ∈ ℝ^{m×d_v}"` → `"&self, q: &[f32], k: &[f32], v: &[f32]"`
/// - `"A ∈ ℝ^{m×p}, B ∈ ℝ^{p×n}"` → `"&self, a: &[f32], b: &[f32]"`
/// - `"W ∈ ℝ^{d×d}, w ∈ ℝ^d"` → `"&self, w: &[f32], w_2: &[f32]"`
/// - `None` → `"&self, input: &[f32]"`
fn domain_to_params(domain: Option<&str>) -> String {
    const FALLBACK: &str = "&self, input: &[f32]";

    let Some(domain) = domain else {
        return FALLBACK.to_string();
    };

    // `names` mirrors `params` and exists only to detect identifier clashes.
    let mut names: Vec<String> = Vec::new();
    let mut params: Vec<String> = Vec::new();
    for segment in domain.split(',') {
        let segment = segment.trim();

        // Extract variable name: text BEFORE "∈" or " in ".
        let Some((left, _)) = segment
            .split_once('∈')
            .or_else(|| segment.split_once(" in "))
        else {
            continue; // No separator — skip (e.g., "beta1 = 0.9")
        };
        let var = left.trim();

        // Expressions and comparisons are not variable declarations.
        if var.is_empty() || var.contains(|c: char| matches!(c, '(' | '>' | '<')) {
            continue;
        }

        // Clean: lowercase, keep only ascii alphanumeric + underscore
        let clean: String = var
            .chars()
            .filter(|c| c.is_ascii_alphanumeric() || *c == '_')
            .collect::<String>()
            .to_lowercase();

        // Filter out non-variable names
        if clean.is_empty()
            || clean.len() > 20
            || clean.starts_with("num")
            || clean.starts_with("beta")
            || clean.starts_with("eps")
            || clean.starts_with(|c: char| c.is_ascii_digit())
        {
            continue;
        }

        // Scalar vs array: scalar if domain is ℝ without exponent (no ^ or ×)
        let is_scalar = segment.contains('ℝ') && !segment.contains('^') && !segment.contains('×');
        let rust_type = if is_scalar { "f32" } else { "&[f32]" };

        // Lowercasing can fold distinct variables onto one identifier; pick
        // the first free `<name>_<n>` so the generated signature compiles.
        let mut name = clean.clone();
        let mut suffix = 2;
        while names.contains(&name) {
            name = format!("{clean}_{suffix}");
            suffix += 1;
        }

        params.push(format!("{name}: {rust_type}"));
        names.push(name);
    }

    if params.is_empty() {
        FALLBACK.to_string()
    } else {
        format!("&self, {}", params.join(", "))
    }
}
187
#[cfg(test)]
mod domain_tests {
    use super::domain_to_params;

    /// A lone vector variable becomes a single slice parameter.
    #[test]
    fn single_vector() {
        let params = domain_to_params(Some("x ∈ ℝ^n"));
        assert_eq!(params, "&self, x: &[f32]");
    }

    /// Three matrix operands become three lowercased slice parameters.
    #[test]
    fn qkv_attention() {
        let domain = "Q ∈ ℝ^{n×d_k}, K ∈ ℝ^{m×d_k}, V ∈ ℝ^{m×d_v}";
        let params = domain_to_params(Some(domain));
        assert_eq!(params, "&self, q: &[f32], k: &[f32], v: &[f32]");
    }

    /// Two matrix operands keep their declaration order.
    #[test]
    fn matmul_ab() {
        let domain = "A ∈ ℝ^{m×p}, B ∈ ℝ^{p×n}";
        let params = domain_to_params(Some(domain));
        assert_eq!(params, "&self, a: &[f32], b: &[f32]");
    }

    /// Segments without a membership separator are dropped.
    #[test]
    fn rope_with_position() {
        let domain = "x ∈ ℝ^d, m ∈ ℕ, θ_k = 10000^(-2k/d)";
        let params = domain_to_params(Some(domain));
        assert_eq!(params, "&self, x: &[f32], m: &[f32]");
    }

    /// ASCII ` in ` is accepted and hyperparameter-like names are filtered.
    #[test]
    fn adamw_filters_scalars() {
        let domain = "g_t in R^d, m_0 = 0, beta1 in (0, 1)";
        let params = domain_to_params(Some(domain));
        assert_eq!(params, "&self, g_t: &[f32]");
    }

    /// A missing domain falls back to the generic input slice.
    #[test]
    fn none_domain() {
        let params = domain_to_params(None);
        assert_eq!(params, "&self, input: &[f32]");
    }

    /// An empty domain string also falls back to the generic input slice.
    #[test]
    fn empty_domain() {
        let params = domain_to_params(Some(""));
        assert_eq!(params, "&self, input: &[f32]");
    }
}
231
/// Convert a contract stem to a `PascalCase` trait name.
///
/// The stem is split on `-`; the first character of every piece is
/// upper-cased and the pieces are concatenated. Empty pieces (from leading,
/// trailing or doubled hyphens) contribute nothing.
///
/// `softmax-kernel-v1` becomes `SoftmaxKernelV1`
fn stem_to_trait_name(stem: &str) -> String {
    let mut trait_name = String::new();
    for word in stem.split('-') {
        let mut rest = word.chars();
        if let Some(initial) = rest.next() {
            // Upper-casing one char may expand to several chars.
            trait_name.extend(initial.to_uppercase());
            trait_name.push_str(rest.as_str());
        }
    }
    trait_name
}
249
250/// Generate failing contract test stubs from a contract.
251///
252/// Each falsification test becomes a `#[test]` with `todo!()`.
253pub fn generate_contract_tests(contract: &Contract) -> String {
254    let mut out = String::new();
255
256    out.push_str("#[cfg(test)]\nmod contract_tests {\n");
257    out.push_str("    use super::*;\n\n");
258
259    for test in &contract.falsification_tests {
260        out.push_str(&format!("    /// {}: {}\n", test.id, test.rule));
261        out.push_str(&format!("    /// Prediction: {}\n", test.prediction));
262        out.push_str(&format!("    /// If fails: {}\n", test.if_fails));
263        let fn_name = test.id.to_lowercase().replace('-', "_");
264        out.push_str(&format!("    #[test]\n    fn {fn_name}() {{\n"));
265        out.push_str(&format!(
266            "        todo!(\"Implementation not yet written — \
267                     {} MUST fail\")\n",
268            test.id
269        ));
270        out.push_str("    }\n\n");
271    }
272
273    out.push_str("}\n");
274    out
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::parse_contract_str;

    /// Single-equation fixture: one `softmax` equation, one proof
    /// obligation and two falsification tests.
    fn sample_contract() -> Contract {
        let yaml = r#"
metadata:
  version: "1.0.0"
  description: "Test kernel"
  references:
    - "Paper (2024)"
equations:
  softmax:
    formula: "σ(x) = exp(x-max) / Σexp(x-max)"
    domain: "ℝ^n"
    codomain: "(0,1)^n"
    invariants:
      - "sum(output) = 1.0"
proof_obligations:
  - type: invariant
    property: "normalization"
    formal: "|sum(σ(x)) - 1.0| < ε"
falsification_tests:
  - id: FALSIFY-SM-001
    rule: "normalization"
    prediction: "sum(output) ≈ 1.0"
    if_fails: "missing max subtraction"
  - id: FALSIFY-SM-002
    rule: "positivity"
    prediction: "output > 0"
    if_fails: "exp underflow"
"#;
        parse_contract_str(yaml).unwrap()
    }

    /// The generic trait names itself, the equation method and its invariant.
    #[test]
    fn generate_trait_includes_equations() {
        let rendered = generate_trait(&sample_contract());
        assert!(rendered.contains("pub trait KernelContract"));
        assert!(rendered.contains("fn softmax"));
        assert!(rendered.contains("INVARIANT: sum(output) = 1.0"));
    }

    /// Every falsification test id turns into a `todo!()` stub.
    #[test]
    fn generate_tests_creates_stubs() {
        let rendered = generate_contract_tests(&sample_contract());
        assert!(rendered.contains("fn falsify_sm_001()"));
        assert!(rendered.contains("fn falsify_sm_002()"));
        assert!(rendered.contains("todo!"));
    }

    /// Prediction and failure-hint text is carried into the stub docs.
    #[test]
    fn generate_tests_includes_predictions() {
        let rendered = generate_contract_tests(&sample_contract());
        assert!(rendered.contains("sum(output) ≈ 1.0"));
        assert!(rendered.contains("missing max subtraction"));
    }

    /// References are listed as `Paper:` lines in the generic trait.
    #[test]
    fn generate_trait_includes_paper_refs() {
        let rendered = generate_trait(&sample_contract());
        assert!(rendered.contains("Paper: Paper (2024)"));
    }

    /// Domain and codomain labels appear when the equation declares them.
    #[test]
    fn generate_trait_includes_domain_codomain() {
        let rendered = generate_trait(&sample_contract());
        assert!(rendered.contains("Domain:"));
        assert!(rendered.contains("Codomain:"));
    }

    /// Proof obligations show up with their upper-cased type and property.
    #[test]
    fn generate_trait_includes_proof_obligation() {
        let rendered = generate_trait(&sample_contract());
        assert!(rendered.contains("INVARIANT"));
        assert!(rendered.contains("normalization"));
    }

    /// Hyphen-separated stems are converted to PascalCase.
    #[test]
    fn stem_to_trait_name_basic() {
        assert_eq!(stem_to_trait_name("softmax-kernel-v1"), "SoftmaxKernelV1");
        assert_eq!(stem_to_trait_name("gelu-kernel-v1"), "GeluKernelV1");
        assert_eq!(stem_to_trait_name("a"), "A");
        assert_eq!(stem_to_trait_name(""), "");
    }

    /// The standalone file carries its banner, lint allowance and trait name.
    #[test]
    fn generate_standalone_trait_header() {
        let rendered = generate_standalone_trait(&sample_contract(), "softmax-kernel-v1");
        assert!(rendered.contains("pub trait SoftmaxKernelV1"));
        assert!(rendered.contains("Auto-generated contract trait"));
        assert!(rendered.contains("DO NOT EDIT"));
        assert!(rendered.contains("#![allow(clippy::doc_markdown)]"));
    }

    /// Standalone methods are named after equations and return `Vec<f32>`.
    #[test]
    fn generate_standalone_trait_methods() {
        let rendered = generate_standalone_trait(&sample_contract(), "softmax-kernel-v1");
        assert!(rendered.contains("fn softmax("));
        assert!(rendered.contains("-> Vec<f32>"));
    }

    /// Equation invariants are documented on the standalone methods.
    #[test]
    fn generate_standalone_trait_invariants() {
        let rendered = generate_standalone_trait(&sample_contract(), "softmax-kernel-v1");
        assert!(rendered.contains("Invariant: sum(output) = 1.0"));
    }

    /// References are listed as `Reference:` lines in the standalone trait.
    #[test]
    fn generate_standalone_trait_references() {
        let rendered = generate_standalone_trait(&sample_contract(), "softmax-kernel-v1");
        assert!(rendered.contains("Reference: Paper (2024)"));
    }

    /// The implementor note reports the equation count.
    #[test]
    fn generate_standalone_trait_implementor_note() {
        let rendered = generate_standalone_trait(&sample_contract(), "test-v1");
        assert!(rendered.contains("Implementors must provide all 1 equation(s)"));
        assert!(rendered.contains("Missing method = compile error"));
    }

    /// The generated test module has its scaffolding and every stub.
    #[test]
    fn generate_contract_tests_all_ids() {
        let rendered = generate_contract_tests(&sample_contract());
        assert!(rendered.contains("#[cfg(test)]"));
        assert!(rendered.contains("mod contract_tests"));
        assert!(rendered.contains("use super::*;"));
        assert!(rendered.contains("fn falsify_sm_001()"));
        assert!(rendered.contains("fn falsify_sm_002()"));
    }

    /// Two-equation fixture with a `bound` obligation and one falsification
    /// test.
    fn multi_equation_contract() -> Contract {
        let yaml = r#"
metadata:
  version: "2.0.0"
  description: "Multi-equation kernel"
  references:
    - "Ref A"
    - "Ref B"
equations:
  alpha:
    formula: "alpha(x) = x^2"
    domain: "x ∈ ℝ^n"
    codomain: "ℝ^n"
    invariants:
      - "output >= 0"
  beta:
    formula: "beta(x) = 2x"
    domain: "x ∈ ℝ^n"
    invariants:
      - "output proportional to input"
proof_obligations:
  - type: bound
    property: "non-negativity"
    formal: "∀x: alpha(x) ≥ 0"
falsification_tests:
  - id: FALSIFY-MQ-001
    rule: "non-neg"
    prediction: "alpha >= 0"
    if_fails: "squared value is negative"
"#;
        parse_contract_str(yaml).unwrap()
    }

    /// Every equation gets a method and the obligation type is upper-cased.
    #[test]
    fn generate_trait_multiple_equations() {
        let rendered = generate_trait(&multi_equation_contract());
        assert!(rendered.contains("fn alpha("));
        assert!(rendered.contains("fn beta("));
        assert!(rendered.contains("BOUND"));
    }

    /// The standalone trait covers every equation and reports the count.
    #[test]
    fn generate_standalone_multiple_equations() {
        let rendered = generate_standalone_trait(&multi_equation_contract(), "multi-eq-v1");
        assert!(rendered.contains("pub trait MultiEqV1"));
        assert!(rendered.contains("fn alpha("));
        assert!(rendered.contains("fn beta("));
        assert!(rendered.contains("2 equation(s)"));
    }

    /// The contract version is printed in the generic trait header.
    #[test]
    fn generate_trait_version_in_header() {
        let rendered = generate_trait(&sample_contract());
        assert!(rendered.contains("v1.0.0"));
    }
}