use crate::schema::Contract;
/// Render a `KernelContract` trait declaration for `contract`.
///
/// The output starts with `///` lines carrying the contract description and
/// version, plus one `Paper:` line per metadata reference. These are followed by
/// `pub trait KernelContract` with one method per equation. Each method is
/// preceded by doc lines for:
/// - the equation's formula,
/// - its optional domain and codomain,
/// - its invariants (prefixed `INVARIANT:`),
/// - every proof obligation declared on the contract.
///
/// Every method has the fixed signature
/// `fn name(&self, input: &[f32], output: &mut [f32]);`.
///
/// Equation names become method identifiers the same way
/// `generate_standalone_trait` converts them (`-` → `_`, then lowercased), so a
/// hyphenated equation name no longer yields an invalid `fn` item.
/// Multi-line text fields are split into one `///` line per source line so they
/// cannot leak uncommented text into the generated code.
pub fn generate_trait(contract: &Contract) -> String {
    // Append `content` as `///` doc-comment lines at `indent`, one per line of
    // `content`. A trailing `\r` is dropped because rustc rejects a bare CR
    // inside a doc comment. For single-line content this emits exactly
    // `{indent}/// {content}\n`.
    fn push_doc(out: &mut String, indent: &str, content: &str) {
        for line in content.split('\n') {
            out.push_str(indent);
            out.push_str("/// ");
            out.push_str(line.strip_suffix('\r').unwrap_or(line));
            out.push('\n');
        }
    }

    let mut out = String::new();
    let meta = &contract.metadata;
    push_doc(
        &mut out,
        "",
        &format!("Contract: {} v{}", meta.description, meta.version),
    );
    for r in &meta.references {
        push_doc(&mut out, "", &format!("Paper: {r}"));
    }
    out.push_str("pub trait KernelContract {\n");
    for (name, eq) in &contract.equations {
        push_doc(&mut out, " ", &eq.formula);
        if let Some(domain) = &eq.domain {
            push_doc(&mut out, " ", &format!("Domain: {domain}"));
        }
        if let Some(codomain) = &eq.codomain {
            push_doc(&mut out, " ", &format!("Codomain: {codomain}"));
        }
        for inv in &eq.invariants {
            push_doc(&mut out, " ", &format!("INVARIANT: {inv}"));
        }
        // Proof obligations are contract-wide, so they are listed on every method.
        for ob in &contract.proof_obligations {
            push_doc(
                &mut out,
                " ",
                &format!(
                    "{} ({}): {}",
                    ob.obligation_type.to_string().to_uppercase(),
                    ob.property,
                    ob.formal.as_deref().unwrap_or("")
                ),
            );
        }
        // Same name normalization as `generate_standalone_trait`.
        let method_name = name.replace('-', "_").to_lowercase();
        out.push_str(&format!(
            " fn {method_name}(&self, input: &[f32], output: &mut [f32]);\n"
        ));
    }
    out.push_str("}\n");
    out
}
/// Render a self-contained Rust source file declaring one trait for `contract`.
///
/// The file contains, in order:
/// - `//!` module docs naming `stem` and the regeneration command;
/// - a crate-level `#![allow(clippy::doc_markdown)]`;
/// - trait docs with the version, the description, `Reference:` lines and an
///   implementor note stating the equation count;
/// - `pub trait {TraitName}`, where the name comes from `stem_to_trait_name(stem)`.
///
/// Each equation becomes a method `fn {name}({params}) -> Vec<f32>;`:
/// - `name` is the equation name with `-` → `_`, lowercased;
/// - `params` comes from `domain_to_params` applied to the equation's domain.
///
/// Each method is preceded by doc lines for the formula, the optional
/// domain/codomain and the invariants. Consecutive methods are separated by a
/// blank line.
///
/// Multi-line text fields are split into one `///` line per source line so they
/// cannot leak uncommented text into the generated code. Output for
/// single-line fields is unchanged.
pub fn generate_standalone_trait(contract: &Contract, stem: &str) -> String {
    // Append `content` as `///` doc-comment lines at `indent`, one per line of
    // `content`. A trailing `\r` is dropped because rustc rejects a bare CR
    // inside a doc comment. For single-line content this emits exactly
    // `{indent}/// {content}\n`.
    fn push_doc(out: &mut String, indent: &str, content: &str) {
        for line in content.split('\n') {
            out.push_str(indent);
            out.push_str("/// ");
            out.push_str(line.strip_suffix('\r').unwrap_or(line));
            out.push('\n');
        }
    }

    let trait_name = stem_to_trait_name(stem);
    let meta = &contract.metadata;
    let eq_count = contract.equations.len();
    let mut out = String::new();

    // File header.
    out.push_str(&format!(
        "//! Auto-generated contract trait for `{stem}`.\n"
    ));
    out.push_str(&format!(
        "//! Generated by: `pv scaffold --trait contracts/{stem}.yaml`\n"
    ));
    out.push_str("//! DO NOT EDIT — regenerate from YAML source.\n\n");
    out.push_str("#![allow(clippy::doc_markdown)]\n\n");

    // Trait-level documentation.
    out.push_str(&format!(
        "/// Contract trait for `{stem}` v{}.\n",
        meta.version
    ));
    out.push_str("///\n");
    push_doc(&mut out, "", &meta.description);
    for r in &meta.references {
        push_doc(&mut out, "", &format!("Reference: {r}"));
    }
    out.push_str("///\n");
    out.push_str(&format!(
        "/// Implementors must provide all {eq_count} equation(s).\n"
    ));
    out.push_str("/// Missing method = compile error. Wrong signature = compile error.\n");
    out.push_str(&format!("pub trait {trait_name} {{\n"));

    // One required method per equation.
    for (i, (name, eq)) in contract.equations.iter().enumerate() {
        push_doc(&mut out, " ", &format!("`{name}`: {}", eq.formula));
        if let Some(domain) = &eq.domain {
            push_doc(&mut out, " ", &format!("Domain: {domain}"));
        }
        if let Some(codomain) = &eq.codomain {
            push_doc(&mut out, " ", &format!("Codomain: {codomain}"));
        }
        for inv in &eq.invariants {
            push_doc(&mut out, " ", &format!("Invariant: {inv}"));
        }
        let method_name = name.replace('-', "_").to_lowercase();
        let params = domain_to_params(eq.domain.as_deref());
        out.push_str(&format!(" fn {method_name}({params}) -> Vec<f32>;\n"));
        // Blank line between methods, but not after the last one.
        if i + 1 < eq_count {
            out.push('\n');
        }
    }
    out.push_str("}\n");
    out
}
/// Build the parameter list for a generated trait method from an equation's
/// `domain` string.
///
/// The domain is split on `,`. Each segment of the form `var ∈ Set` or
/// `var in Set` may contribute one parameter named after `var`, lowercased and
/// reduced to ASCII alphanumerics and `_`.
///
/// A segment is skipped when:
/// - it has no membership operator (e.g. `m_0 = 0`);
/// - `var` contains `(`, `<` or `>`;
/// - the cleaned name is empty, longer than 20 bytes, or starts with a digit;
/// - the cleaned name starts with `num`, `beta` or `eps` (treated as
///   hyper-parameters rather than inputs).
///
/// A segment mentioning `ℝ` but neither `^` nor `×` is a scalar and maps to
/// `f32`. Every other accepted segment maps to `&[f32]`.
///
/// To keep the result a valid Rust parameter list:
/// - a name that is a Rust keyword gets a trailing `_` (`type` → `type_`);
/// - a name already emitted is not emitted again, so the first occurrence wins
///   (e.g. `x ∈ ℝ^n, X ∈ ℝ^m` yields only `x`).
///
/// Returns `"&self, input: &[f32]"` when `domain` is `None` or yields no
/// parameters; otherwise `"&self, "` followed by the comma-separated params.
fn domain_to_params(domain: Option<&str>) -> String {
    // Signature used when no usable variable is found.
    const DEFAULT_PARAMS: &str = "&self, input: &[f32]";
    // Strict and reserved Rust keywords (all editions) that cannot be used as
    // plain parameter names. Only lowercase spellings matter: names are
    // lowercased before the lookup.
    const RUST_KEYWORDS: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const",
        "continue", "crate", "do", "dyn", "else", "enum", "extern", "false",
        "final", "fn", "for", "gen", "if", "impl", "in", "let", "loop", "macro",
        "match", "mod", "move", "mut", "override", "priv", "pub", "ref",
        "return", "self", "static", "struct", "super", "trait", "true", "try",
        "type", "typeof", "unsafe", "unsized", "use", "virtual", "where",
        "while", "yield",
    ];

    let Some(domain) = domain else {
        return DEFAULT_PARAMS.to_string();
    };
    let mut names: Vec<String> = Vec::new();
    let mut params: Vec<String> = Vec::new();
    for segment in domain.split(',') {
        let segment = segment.trim();
        let var = if let Some((left, _)) = segment.split_once('∈') {
            left.trim()
        } else if let Some((left, _)) = segment.split_once(" in ") {
            left.trim()
        } else {
            // Not a membership statement (e.g. `m_0 = 0`, `θ_k = ...`).
            continue;
        };
        if var.is_empty() || var.contains('(') || var.contains('>') || var.contains('<') {
            continue;
        }
        let clean: String = var
            .chars()
            .filter(|c| c.is_ascii_alphanumeric() || *c == '_')
            .collect::<String>()
            .to_lowercase();
        if clean.is_empty()
            || clean.len() > 20
            || clean.starts_with("num")
            || clean.starts_with("beta")
            || clean.starts_with("eps")
            || clean.starts_with(|c: char| c.is_ascii_digit())
        {
            continue;
        }
        // A keyword is not a legal identifier: suffix it so the signature compiles.
        let name = if RUST_KEYWORDS.contains(&clean.as_str()) {
            format!("{clean}_")
        } else {
            clean
        };
        // Binding the same identifier twice is a compile error; keep the first.
        if names.contains(&name) {
            continue;
        }
        let is_scalar = segment.contains('ℝ') && !segment.contains('^') && !segment.contains('×');
        let rust_type = if is_scalar { "f32" } else { "&[f32]" };
        params.push(format!("{name}: {rust_type}"));
        names.push(name);
    }
    if params.is_empty() {
        DEFAULT_PARAMS.to_string()
    } else {
        format!("&self, {}", params.join(", "))
    }
}
#[cfg(test)]
mod domain_tests {
    use super::domain_to_params;

    // Shorthand for the common case where a domain string is present.
    fn params_for(domain: &str) -> String {
        domain_to_params(Some(domain))
    }

    #[test]
    fn single_vector() {
        let params = params_for("x ∈ ℝ^n");
        assert_eq!(params, "&self, x: &[f32]");
    }

    #[test]
    fn qkv_attention() {
        let params = params_for("Q ∈ ℝ^{n×d_k}, K ∈ ℝ^{m×d_k}, V ∈ ℝ^{m×d_v}");
        assert_eq!(params, "&self, q: &[f32], k: &[f32], v: &[f32]");
    }

    #[test]
    fn matmul_ab() {
        let params = params_for("A ∈ ℝ^{m×p}, B ∈ ℝ^{p×n}");
        assert_eq!(params, "&self, a: &[f32], b: &[f32]");
    }

    #[test]
    fn rope_with_position() {
        // `θ_k = ...` is a definition, not a membership, so it adds nothing.
        let params = params_for("x ∈ ℝ^d, m ∈ ℕ, θ_k = 10000^(-2k/d)");
        assert_eq!(params, "&self, x: &[f32], m: &[f32]");
    }

    #[test]
    fn adamw_filters_scalars() {
        // `m_0 = 0` has no membership and `beta1` is a filtered hyper-parameter.
        let params = params_for("g_t in R^d, m_0 = 0, beta1 in (0, 1)");
        assert_eq!(params, "&self, g_t: &[f32]");
    }

    #[test]
    fn none_domain() {
        let params = domain_to_params(None);
        assert_eq!(params, "&self, input: &[f32]");
    }

    #[test]
    fn empty_domain() {
        let params = params_for("");
        assert_eq!(params, "&self, input: &[f32]");
    }
}
/// Convert a kebab-case file stem into an UpperCamelCase trait name.
///
/// The stem is split on `-` only; `_` is not treated as a separator. Each
/// non-empty piece has its first character uppercased via `char::to_uppercase`
/// (which may expand to several characters, e.g. `ß` → `SS`), and the rest of
/// the piece is copied unchanged. Empty pieces, which come from leading,
/// trailing or doubled `-`, contribute nothing, so `""` maps to `""`.
fn stem_to_trait_name(stem: &str) -> String {
    let mut name = String::with_capacity(stem.len());
    for word in stem.split('-') {
        let mut chars = word.chars();
        if let Some(head) = chars.next() {
            name.extend(head.to_uppercase());
            name.push_str(chars.as_str());
        }
    }
    name
}
/// Render a `#[cfg(test)] mod contract_tests` containing one stub test per
/// falsification test in `contract`.
///
/// Each stub is preceded by doc lines for the test's id and rule, its
/// prediction and its failure hint. Each stub body is a `todo!` whose message
/// names the test id, so the stub fails until it is implemented.
///
/// Robustness guarantees on the generated code:
/// - The stub's fn name is the lowercased id with every character that cannot
///   appear in an identifier replaced by `_`. For the usual `FALSIFY-XX-NNN`
///   ids this is exactly the old `-` → `_` mapping.
/// - The `todo!` message is emitted as an escaped string literal with `{`/`}`
///   doubled, because `todo!` treats its argument as a format string.
/// - Multi-line text is split into one `///` line per source line.
pub fn generate_contract_tests(contract: &Contract) -> String {
    // Append `content` as `///` doc-comment lines at `indent`, one per line of
    // `content`. A trailing `\r` is dropped because rustc rejects a bare CR
    // inside a doc comment. For single-line content this emits exactly
    // `{indent}/// {content}\n`.
    fn push_doc(out: &mut String, indent: &str, content: &str) {
        for line in content.split('\n') {
            out.push_str(indent);
            out.push_str("/// ");
            out.push_str(line.strip_suffix('\r').unwrap_or(line));
            out.push('\n');
        }
    }

    // Lowercase `id` and replace every non-identifier character with `_`.
    fn test_fn_name(id: &str) -> String {
        id.to_lowercase()
            .chars()
            .map(|c| if c.is_alphanumeric() || c == '_' { c } else { '_' })
            .collect()
    }

    let mut out = String::new();
    out.push_str("#[cfg(test)]\nmod contract_tests {\n");
    out.push_str(" use super::*;\n\n");
    for test in &contract.falsification_tests {
        push_doc(&mut out, " ", &format!("{}: {}", test.id, test.rule));
        push_doc(&mut out, " ", &format!("Prediction: {}", test.prediction));
        push_doc(&mut out, " ", &format!("If fails: {}", test.if_fails));
        let fn_name = test_fn_name(&test.id);
        out.push_str(&format!(" #[test]\n fn {fn_name}() {{\n"));
        let message = format!("Implementation not yet written — {} MUST fail", test.id);
        // `todo!` parses its first argument as a format string: double the
        // braces, then `{:?}` yields a correctly quoted and escaped literal.
        let literal = format!("{:?}", message.replace('{', "{{").replace('}', "}}"));
        out.push_str(&format!(" todo!({literal})\n"));
        out.push_str(" }\n\n");
    }
    out.push_str("}\n");
    out
}
// Unit tests for `generate_trait`, `generate_standalone_trait`,
// `generate_contract_tests` and `stem_to_trait_name`. They are driven by
// contracts parsed from inline YAML via `parse_contract_str`.
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::parse_contract_str;
/// Single-equation fixture, with:
/// - one equation, `softmax`, with domain, codomain and one invariant;
/// - one reference;
/// - one `invariant` proof obligation;
/// - two falsification tests (`FALSIFY-SM-001`, `FALSIFY-SM-002`).
///
/// Panics if the YAML fails to parse.
fn sample_contract() -> Contract {
parse_contract_str(
r#"
metadata:
version: "1.0.0"
description: "Test kernel"
references:
- "Paper (2024)"
equations:
softmax:
formula: "σ(x) = exp(x-max) / Σexp(x-max)"
domain: "ℝ^n"
codomain: "(0,1)^n"
invariants:
- "sum(output) = 1.0"
proof_obligations:
- type: invariant
property: "normalization"
formal: "|sum(σ(x)) - 1.0| < ε"
falsification_tests:
- id: FALSIFY-SM-001
rule: "normalization"
prediction: "sum(output) ≈ 1.0"
if_fails: "missing max subtraction"
- id: FALSIFY-SM-002
rule: "positivity"
prediction: "output > 0"
if_fails: "exp underflow"
"#,
)
.unwrap()
}
// The trait header, a method per equation and `INVARIANT:`-prefixed invariants.
#[test]
fn generate_trait_includes_equations() {
let contract = sample_contract();
let code = generate_trait(&contract);
assert!(code.contains("pub trait KernelContract"));
assert!(code.contains("fn softmax"));
assert!(code.contains("INVARIANT: sum(output) = 1.0"));
}
// One stub fn per falsification test; ids are lowercased with `-` → `_`.
#[test]
fn generate_tests_creates_stubs() {
let contract = sample_contract();
let code = generate_contract_tests(&contract);
assert!(code.contains("fn falsify_sm_001()"));
assert!(code.contains("fn falsify_sm_002()"));
assert!(code.contains("todo!"));
}
// Prediction and failure hint are carried into the stub's doc comments.
#[test]
fn generate_tests_includes_predictions() {
let contract = sample_contract();
let code = generate_contract_tests(&contract);
assert!(code.contains("sum(output) ≈ 1.0"));
assert!(code.contains("missing max subtraction"));
}
#[test]
fn generate_trait_includes_paper_refs() {
let contract = sample_contract();
let code = generate_trait(&contract);
assert!(code.contains("Paper: Paper (2024)"));
}
#[test]
fn generate_trait_includes_domain_codomain() {
let contract = sample_contract();
let code = generate_trait(&contract);
assert!(code.contains("Domain:"));
assert!(code.contains("Codomain:"));
}
// NOTE(review): "INVARIANT" also matches the per-equation invariant lines, so
// this assertion alone does not prove the obligation was rendered. The
// obligation's `property` ("normalization") is the distinguishing check;
// `generate_trait` never reads `falsification_tests`, where that word also appears.
#[test]
fn generate_trait_includes_proof_obligation() {
let contract = sample_contract();
let code = generate_trait(&contract);
assert!(code.contains("INVARIANT"));
assert!(code.contains("normalization"));
}
// Kebab-case stems become UpperCamelCase; the empty stem stays empty.
#[test]
fn stem_to_trait_name_basic() {
assert_eq!(stem_to_trait_name("softmax-kernel-v1"), "SoftmaxKernelV1");
assert_eq!(stem_to_trait_name("gelu-kernel-v1"), "GeluKernelV1");
assert_eq!(stem_to_trait_name("a"), "A");
assert_eq!(stem_to_trait_name(""), "");
}
#[test]
fn generate_standalone_trait_header() {
let contract = sample_contract();
let code = generate_standalone_trait(&contract, "softmax-kernel-v1");
assert!(code.contains("pub trait SoftmaxKernelV1"));
assert!(code.contains("Auto-generated contract trait"));
assert!(code.contains("DO NOT EDIT"));
assert!(code.contains("#![allow(clippy::doc_markdown)]"));
}
// Standalone methods return `Vec<f32>` rather than writing into an output slice.
#[test]
fn generate_standalone_trait_methods() {
let contract = sample_contract();
let code = generate_standalone_trait(&contract, "softmax-kernel-v1");
assert!(code.contains("fn softmax("));
assert!(code.contains("-> Vec<f32>"));
}
#[test]
fn generate_standalone_trait_invariants() {
let contract = sample_contract();
let code = generate_standalone_trait(&contract, "softmax-kernel-v1");
assert!(code.contains("Invariant: sum(output) = 1.0"));
}
#[test]
fn generate_standalone_trait_references() {
let contract = sample_contract();
let code = generate_standalone_trait(&contract, "softmax-kernel-v1");
assert!(code.contains("Reference: Paper (2024)"));
}
#[test]
fn generate_standalone_trait_implementor_note() {
let contract = sample_contract();
let code = generate_standalone_trait(&contract, "test-v1");
assert!(code.contains("Implementors must provide all 1 equation(s)"));
assert!(code.contains("Missing method = compile error"));
}
#[test]
fn generate_contract_tests_all_ids() {
let contract = sample_contract();
let code = generate_contract_tests(&contract);
assert!(code.contains("#[cfg(test)]"));
assert!(code.contains("mod contract_tests"));
assert!(code.contains("use super::*;"));
assert!(code.contains("fn falsify_sm_001()"));
assert!(code.contains("fn falsify_sm_002()"));
}
/// Two-equation fixture, with:
/// - equations `alpha` (domain, codomain, one invariant) and `beta` (domain,
///   one invariant);
/// - two references;
/// - one `bound` proof obligation;
/// - one falsification test.
///
/// Panics if the YAML fails to parse.
fn multi_equation_contract() -> Contract {
parse_contract_str(
r#"
metadata:
version: "2.0.0"
description: "Multi-equation kernel"
references:
- "Ref A"
- "Ref B"
equations:
alpha:
formula: "alpha(x) = x^2"
domain: "x ∈ ℝ^n"
codomain: "ℝ^n"
invariants:
- "output >= 0"
beta:
formula: "beta(x) = 2x"
domain: "x ∈ ℝ^n"
invariants:
- "output proportional to input"
proof_obligations:
- type: bound
property: "non-negativity"
formal: "∀x: alpha(x) ≥ 0"
falsification_tests:
- id: FALSIFY-MQ-001
rule: "non-neg"
prediction: "alpha >= 0"
if_fails: "squared value is negative"
"#,
)
.unwrap()
}
// "BOUND" relies on the obligation type being uppercased by `generate_trait`
// (presumably it displays as `bound`; confirm against the schema's Display impl).
#[test]
fn generate_trait_multiple_equations() {
let contract = multi_equation_contract();
let code = generate_trait(&contract);
assert!(code.contains("fn alpha("));
assert!(code.contains("fn beta("));
assert!(code.contains("BOUND"));
}
#[test]
fn generate_standalone_multiple_equations() {
let contract = multi_equation_contract();
let code = generate_standalone_trait(&contract, "multi-eq-v1");
assert!(code.contains("pub trait MultiEqV1"));
assert!(code.contains("fn alpha("));
assert!(code.contains("fn beta("));
assert!(code.contains("2 equation(s)"));
}
#[test]
fn generate_trait_version_in_header() {
let contract = sample_contract();
let code = generate_trait(&contract);
assert!(code.contains("v1.0.0"));
}
}