Skip to main content

provable_contracts/
extract.rs

1//! `PyTorch` kernel extraction — reads Python source, extracts equations.
2//!
3//! Parses docstrings for LaTeX math, extracts preconditions from type hints
4//! and assertions, and generates YAML contract skeletons.
5
6use std::path::Path;
7
/// Extracted equation from `PyTorch` source.
///
/// Derives `PartialEq`/`Eq` so extraction results can be compared directly
/// (for example in tests with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtractedEquation {
    /// Name of the equation; the extractor uses the function name.
    pub name: String,
    /// Human-readable formula text.
    pub formula: String,
    /// Precondition expressions, stored as source text.
    pub preconditions: Vec<String>,
    /// Postcondition expressions, stored as source text.
    pub postconditions: Vec<String>,
    /// Path of the Python file the equation was extracted from.
    pub source_file: String,
    /// Zero-based index of the function's `def` line within that file.
    pub source_line: usize,
}

/// Extracted kernel from a `PyTorch` source file.
///
/// Derives `PartialEq`/`Eq` so whole extraction results can be compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtractedKernel {
    /// Name of the extracted Python function.
    pub function_name: String,
    /// Path of the file the function was read from.
    pub module_path: String,
    /// Text of the function's docstring (empty when none was found).
    pub docstring: String,
    /// Equations found for this function.
    pub equations: Vec<ExtractedEquation>,
    /// Annotated parameters as `(name, type)` pairs.
    pub arguments: Vec<(String, String)>,
    /// Return type annotation text.
    pub return_type: String,
}
29
30/// Extract a kernel from a `PyTorch` Python source file.
31///
32/// Parses the function definition, docstring, and LaTeX math.
33/// `target` is either a file path or `file.py::function_name`.
34pub fn extract_from_pytorch(target: &str) -> Result<ExtractedKernel, String> {
35    let (file_path, fn_name) = if target.contains("::") {
36        let parts: Vec<&str> = target.splitn(2, "::").collect();
37        (parts[0], Some(parts[1]))
38    } else {
39        (target, None)
40    };
41
42    let content = std::fs::read_to_string(file_path)
43        .map_err(|e| format!("Failed to read {file_path}: {e}"))?;
44
45    let fn_name = fn_name.unwrap_or_else(|| {
46        // Guess from filename
47        Path::new(file_path)
48            .file_stem()
49            .and_then(|s| s.to_str())
50            .unwrap_or("unknown")
51    });
52
53    extract_function(&content, fn_name, file_path)
54}
55
56fn extract_function(
57    content: &str,
58    fn_name: &str,
59    file_path: &str,
60) -> Result<ExtractedKernel, String> {
61    let lines: Vec<&str> = content.lines().collect();
62
63    // Find the function definition
64    let def_pattern = format!("def {fn_name}(");
65    let def_line = lines
66        .iter()
67        .enumerate()
68        .find(|(_, line)| line.trim().starts_with(&def_pattern))
69        .map(|(i, _)| i)
70        .ok_or_else(|| format!("Function `{fn_name}` not found in {file_path}"))?;
71
72    // Extract arguments from the def line
73    let args = extract_arguments(&lines, def_line);
74
75    // Extract docstring
76    let docstring = extract_docstring(&lines, def_line);
77
78    // Extract equations from LaTeX in docstring
79    let equations = extract_equations_from_docstring(&docstring, fn_name, file_path, def_line);
80
81    // Infer return type from type hints
82    let return_type = extract_return_type(&lines, def_line);
83
84    Ok(ExtractedKernel {
85        function_name: fn_name.to_string(),
86        module_path: file_path.to_string(),
87        docstring,
88        equations,
89        arguments: args,
90        return_type,
91    })
92}
93
/// Collects the annotated parameters of the Python function defined at
/// `lines[def_line]` as `(name, type)` pairs.
///
/// Only the text between the signature's opening `(` and its matching `)` is
/// examined, so neither the function name, the return annotation, nor any
/// docstring/body line can leak into the result. Parameters are split on
/// commas that are not nested inside brackets or string literals, which keeps
/// annotations such as `Tuple[int, int]` intact.
///
/// A parameter is reported only when it has an annotation, is not `self`, and
/// its name does not start with `_`. Default values are dropped from the type
/// and a leading `*`/`**` is dropped from the name.
///
/// Returns an empty list when `def_line` is out of range. If the signature is
/// never closed, the parameters completed before the end of input are
/// returned.
fn extract_arguments(lines: &[&str], def_line: usize) -> Vec<(String, String)> {
    /// Parses one raw parameter (`name: Type = default`) into `(name, type)`.
    fn parse_parameter(raw: &str) -> Option<(String, String)> {
        // Cut the default value: the first `=` outside brackets and strings.
        let mut depth = 0usize;
        let mut quote: Option<char> = None;
        let mut end = raw.len();
        for (idx, ch) in raw.char_indices() {
            if let Some(q) = quote {
                if ch == q {
                    quote = None;
                }
                continue;
            }
            match ch {
                '\'' | '"' => quote = Some(ch),
                '(' | '[' | '{' => depth += 1,
                ')' | ']' | '}' => depth = depth.saturating_sub(1),
                '=' if depth == 0 => {
                    end = idx;
                    break;
                }
                _ => {}
            }
        }

        // A parameter name cannot contain `:`, so the first one starts the type.
        let (name, typ) = raw[..end].split_once(':')?;
        let name = name.trim().trim_start_matches('*').trim();
        if name.is_empty() || name == "self" || name.starts_with('_') {
            return None;
        }
        Some((name.to_string(), typ.trim().to_string()))
    }

    let mut args = Vec::new();
    let mut current = String::new();
    let mut depth = 0usize; // bracket nesting; 1 means "directly inside the signature"
    let mut quote: Option<char> = None;
    let mut escaped = false;

    'signature: for line in lines.iter().skip(def_line) {
        for ch in line.chars() {
            if let Some(q) = quote {
                // Inside a string literal: copy verbatim until the closing quote.
                current.push(ch);
                if escaped {
                    escaped = false;
                } else if ch == '\\' {
                    escaped = true;
                } else if ch == q {
                    quote = None;
                }
                continue;
            }
            match ch {
                // The rest of the line is a Python comment.
                '#' => break,
                '\'' | '"' if depth > 0 => {
                    quote = Some(ch);
                    current.push(ch);
                }
                '(' | '[' | '{' => {
                    depth += 1;
                    if depth > 1 {
                        current.push(ch);
                    }
                }
                ')' | ']' | '}' => {
                    if depth == 0 {
                        // Closing bracket before any opening one: not a signature.
                        break 'signature;
                    }
                    depth -= 1;
                    if depth == 0 {
                        // End of the parameter list: flush the last parameter.
                        args.extend(parse_parameter(&current));
                        break 'signature;
                    }
                    current.push(ch);
                }
                ',' if depth == 1 => {
                    args.extend(parse_parameter(&current));
                    current.clear();
                }
                _ if depth > 0 => current.push(ch),
                _ => {}
            }
        }
        // A line break inside the signature only separates tokens.
        if depth > 0 {
            current.push(' ');
        }
    }

    args
}
126
/// Returns the docstring of the Python function defined at `lines[def_line]`,
/// or an empty string when the function has none.
///
/// The signature is skipped by tracking bracket nesting from the `def` line;
/// the docstring must then be the first thing in the body (blank lines and
/// `#` comment lines are skipped). This prevents a function without a
/// docstring from picking up the docstring of a later definition.
///
/// Both `"""` and `'''` delimiters are recognised, optionally preceded by a
/// single `r`/`R`/`u`/`U` prefix. Every docstring line is trimmed; lines are
/// joined with `\n`, and the text before the closing delimiter is appended
/// without a trailing newline. An unterminated docstring yields everything up
/// to the end of input.
fn extract_docstring(lines: &[&str], def_line: usize) -> String {
    // Find the first line after the signature: the signature ends on the
    // first line at whose end all brackets opened since `def` are closed.
    let mut depth = 0i32;
    let mut body_start = lines.len();
    for (idx, line) in lines.iter().enumerate().skip(def_line) {
        let mut quote: Option<char> = None;
        for ch in line.chars() {
            if let Some(q) = quote {
                if ch == q {
                    quote = None;
                }
                continue;
            }
            match ch {
                // The rest of the line is a comment.
                '#' => break,
                '\'' | '"' => quote = Some(ch),
                '(' | '[' | '{' => depth += 1,
                ')' | ']' | '}' => depth -= 1,
                _ => {}
            }
        }
        if depth <= 0 {
            body_start = idx + 1;
            break;
        }
    }

    let mut body = lines
        .iter()
        .skip(body_start)
        .map(|line| line.trim())
        .skip_while(|line| line.is_empty() || line.starts_with('#'));

    let first = match body.next() {
        Some(line) => line,
        None => return String::new(),
    };

    // Strip at most one string prefix, then require a triple-quote opener.
    let unprefixed = first
        .strip_prefix(|c: char| matches!(c, 'r' | 'R' | 'u' | 'U'))
        .unwrap_or(first);
    let delimiter = match ["\"\"\"", "'''"]
        .iter()
        .copied()
        .find(|candidate| unprefixed.starts_with(*candidate))
    {
        Some(found) => found,
        None => return String::new(),
    };

    let opening = &unprefixed[delimiter.len()..];
    if let Some(end) = opening.find(delimiter) {
        // One-line docstring (this also covers the empty `""""""` case).
        return opening[..end].to_string();
    }

    let mut doc = String::from(opening);
    doc.push('\n');
    for line in body {
        if let Some(end) = line.find(delimiter) {
            doc.push_str(&line[..end]);
            break;
        }
        doc.push_str(line);
        doc.push('\n');
    }
    doc
}
158
159fn extract_equations_from_docstring(
160    docstring: &str,
161    fn_name: &str,
162    file_path: &str,
163    line: usize,
164) -> Vec<ExtractedEquation> {
165    let mut equations = Vec::new();
166
167    // Extract LaTeX math from :math:`...`
168    let mut pos = 0;
169    while let Some(start) = docstring[pos..].find(":math:`") {
170        let abs_start = pos + start + 7; // skip ":math:`"
171        if let Some(end) = docstring[abs_start..].find('`') {
172            let formula = &docstring[abs_start..abs_start + end];
173
174            // Convert LaTeX to readable math
175            let readable = latex_to_readable(formula);
176
177            // Infer preconditions from argument types and docstring
178            let preconditions = infer_preconditions(docstring, fn_name);
179
180            // Infer postconditions from docstring descriptions
181            let postconditions = infer_postconditions(docstring, fn_name);
182
183            equations.push(ExtractedEquation {
184                name: fn_name.to_string(),
185                formula: readable,
186                preconditions,
187                postconditions,
188                source_file: file_path.to_string(),
189                source_line: line,
190            });
191
192            pos = abs_start + end + 1;
193        } else {
194            break;
195        }
196    }
197
198    if equations.is_empty() {
199        // No LaTeX found — create a basic equation from function signature
200        equations.push(ExtractedEquation {
201            name: fn_name.to_string(),
202            formula: format!("{fn_name}(input) → output"),
203            preconditions: vec!["!input.is_empty()".to_string()],
204            postconditions: vec!["ret.iter().all(|x| x.is_finite())".to_string()],
205            source_file: file_path.to_string(),
206            source_line: line,
207        });
208    }
209
210    equations
211}
212
/// Returns the return-type annotation of the Python function defined at
/// `lines[def_line]`, or `"Tensor"` when the signature has none.
///
/// The whole signature is read, however many lines it spans, by tracking
/// bracket nesting from the `def` line. The annotation is accepted only when
/// `->` directly follows the parameter list's closing parenthesis, so arrows
/// in comments, docstrings or the function body are never mistaken for it.
/// Trailing comments and anything from the header's terminating `:` onwards
/// are excluded from the result.
fn extract_return_type(lines: &[&str], def_line: usize) -> String {
    const DEFAULT_RETURN: &str = "Tensor";

    // Join the header lines (comments removed) and remember where the
    // parameter list closes.
    let mut header = String::new();
    let mut params_end: Option<usize> = None;
    let mut depth = 0i32;
    for line in lines.iter().skip(def_line) {
        let mut quote: Option<char> = None;
        for ch in line.trim().chars() {
            if let Some(q) = quote {
                if ch == q {
                    quote = None;
                }
            } else {
                match ch {
                    // The rest of the line is a comment.
                    '#' => break,
                    '\'' | '"' => quote = Some(ch),
                    '(' | '[' | '{' => depth += 1,
                    ')' | ']' | '}' => depth -= 1,
                    _ => {}
                }
            }
            header.push(ch);
            if ch == ')' && depth == 0 && quote.is_none() && params_end.is_none() {
                params_end = Some(header.len());
            }
        }
        // The header ends on the first line that leaves no bracket open.
        if depth <= 0 {
            break;
        }
        header.push(' ');
    }

    let after_params = match params_end {
        Some(end) => header[end..].trim_start(),
        None => return DEFAULT_RETURN.to_string(),
    };
    let annotation = match after_params.strip_prefix("->") {
        Some(text) => text,
        None => return DEFAULT_RETURN.to_string(),
    };

    // The annotation stops at the first `:` that is not nested in brackets.
    let mut nesting = 0i32;
    let mut end = annotation.len();
    for (idx, ch) in annotation.char_indices() {
        match ch {
            '(' | '[' | '{' => nesting += 1,
            ')' | ']' | '}' => nesting -= 1,
            ':' if nesting <= 0 => {
                end = idx;
                break;
            }
            _ => {}
        }
    }
    annotation[..end].trim().to_string()
}
222
/// Converts a LaTeX formula into a more readable plain-text form.
///
/// `\text{...}` wrappers are removed together with their matching closing
/// brace, so `\text{Softmax}(x)` becomes `Softmax(x)` rather than leaving a
/// stray `)` behind. `\frac{a}{b}` becomes `(a) / (b)`, a handful of common
/// commands are mapped to plain names or Unicode symbols, and the remaining
/// braces are turned into parentheses. Commands that are not listed here are
/// left untouched.
fn latex_to_readable(latex: &str) -> String {
    // Pass 1: drop `\text{` and the brace that closes it. Each open brace is
    // tracked so nested groups inside the text keep their own braces.
    let mut unwrapped = String::with_capacity(latex.len());
    let mut open_groups: Vec<bool> = Vec::new(); // true = opened by `\text{`
    let mut rest = latex;
    while let Some(ch) = rest.chars().next() {
        if let Some(tail) = rest.strip_prefix("\\text{") {
            open_groups.push(true);
            rest = tail;
            continue;
        }
        rest = &rest[ch.len_utf8()..];
        match ch {
            '{' => open_groups.push(false),
            '}' => {
                if open_groups.pop() == Some(true) {
                    // This brace closes a `\text{` wrapper: omit it.
                    continue;
                }
            }
            _ => {}
        }
        unwrapped.push(ch);
    }

    // Pass 2: textual substitutions. Order matters: `\frac{` and `}{` must be
    // rewritten before the generic brace-to-parenthesis replacements.
    unwrapped
        .replace("\\frac{", "(")
        .replace("}{", ") / (")
        .replace("\\exp", "exp")
        .replace("\\sum", "Σ")
        .replace("\\log", "log")
        .replace("\\max", "max")
        .replace("\\sqrt", "√")
        .replace("\\sigma", "σ")
        .replace("\\mu", "μ")
        .replace("\\epsilon", "ε")
        .replace('}', ")")
        .replace('{', "(")
        .replace("_((", "_(")
}
240
/// Derives precondition expressions from keywords in `docstring`.
///
/// The non-empty-input condition is always present. A docstring mentioning
/// `dim` adds a dimension bound, and one mentioning `positive` or `> 0` adds
/// a strict-positivity condition. `_fn_name` is currently unused.
fn infer_preconditions(docstring: &str, _fn_name: &str) -> Vec<String> {
    let mentions_dim = docstring.contains("dim");
    let mentions_positive = docstring.contains("positive") || docstring.contains("> 0");

    // (applies, expression) pairs, listed in output order.
    let candidates = [
        (true, "!input.is_empty()"),
        (mentions_dim, "dim < input.ndim()"),
        (mentions_positive, "input.iter().all(|x| *x > 0.0)"),
    ];

    candidates
        .iter()
        .filter(|(applies, _)| *applies)
        .map(|(_, expression)| (*expression).to_owned())
        .collect()
}
253
/// Derives postcondition expressions from keywords in `docstring`.
///
/// Mentions of the `[0, 1]` range, of outputs that "sum to 1"/"sum to one",
/// and of "normalized"/"unit" each contribute one expression, in that order.
/// When nothing matches, the result is a single finiteness check, so the
/// returned list is never empty. `_fn_name` is currently unused.
fn infer_postconditions(docstring: &str, _fn_name: &str) -> Vec<String> {
    /// True when `text` contains at least one of `needles`.
    fn mentions_any(text: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| text.contains(*needle))
    }

    const ALL_FINITE: &str = "ret.iter().all(|x| x.is_finite())";

    let mut posts: Vec<String> = Vec::new();

    if mentions_any(docstring, &["[0, 1]", "range `[0, 1]`"]) {
        posts.push(String::from("ret.iter().all(|&v| v >= 0.0 && v <= 1.0)"));
    }
    if mentions_any(docstring, &["sum to 1", "sum to one"]) {
        posts.push(String::from("(ret.iter().sum::<f32>() - 1.0).abs() < 1e-6"));
    }
    if mentions_any(docstring, &["normalized", "unit"]) {
        posts.push(String::from(ALL_FINITE));
    }

    // Always report at least the generic finiteness guarantee.
    if posts.is_empty() {
        posts.push(String::from(ALL_FINITE));
    }

    posts
}
273
274/// Generate YAML contract from extracted kernel.
275pub fn kernel_to_yaml(kernel: &ExtractedKernel) -> String {
276    let mut yaml = String::new();
277
278    yaml.push_str(&format!("# Auto-extracted from {}\n", kernel.module_path));
279    yaml.push_str(&format!("# Function: {}\n\n", kernel.function_name));
280
281    yaml.push_str("metadata:\n");
282    yaml.push_str("  version: \"1.0.0\"\n");
283    yaml.push_str("  created: \"2026-03-21\"\n");
284    yaml.push_str("  author: \"pv extract-pytorch\"\n");
285    yaml.push_str(&format!(
286        "  description: \"Contract for {} extracted from PyTorch\"\n",
287        kernel.function_name
288    ));
289    yaml.push_str("  references:\n");
290    yaml.push_str(&format!("    - \"{}\"\n\n", kernel.module_path));
291
292    yaml.push_str("equations:\n");
293    for eq in &kernel.equations {
294        yaml.push_str(&format!("  {}:\n", eq.name));
295        yaml.push_str(&format!(
296            "    formula: \"{}\"\n",
297            eq.formula.replace('"', "'")
298        ));
299        if !eq.preconditions.is_empty() {
300            yaml.push_str("    preconditions:\n");
301            for pre in &eq.preconditions {
302                yaml.push_str(&format!("      - \"{pre}\"\n"));
303            }
304        }
305        if !eq.postconditions.is_empty() {
306            yaml.push_str("    postconditions:\n");
307            for post in &eq.postconditions {
308                yaml.push_str(&format!("      - \"{post}\"\n"));
309            }
310        }
311        yaml.push_str(&format!(
312            "    lean_theorem: \"ProvableContracts.Theorems.{}.Correctness\"\n\n",
313            capitalize(&eq.name)
314        ));
315    }
316
317    yaml.push_str("falsification_tests:\n");
318    yaml.push_str(&format!(
319        "  - id: FALSIFY-{}-001\n",
320        kernel.function_name.to_uppercase()
321    ));
322    yaml.push_str(&format!(
323        "    rule: \"{} correctness\"\n",
324        kernel.function_name
325    ));
326    yaml.push_str(&format!(
327        "    test: \"test_{}_basic\"\n",
328        kernel.function_name
329    ));
330    yaml.push_str(&format!(
331        "    prediction: \"{} output matches PyTorch reference\"\n",
332        kernel.function_name
333    ));
334    yaml.push_str(&format!(
335        "    if_fails: \"{} implementation diverges from PyTorch\"\n",
336        kernel.function_name
337    ));
338
339    yaml
340}
341
/// Returns `s` with its first character upper-cased and the rest unchanged.
///
/// The empty string maps to the empty string. Upper-casing follows Unicode
/// rules, so a single character may expand to several.
fn capitalize(s: &str) -> String {
    let mut rest = s.chars();
    rest.next()
        .map(|first| first.to_uppercase().chain(rest).collect())
        .unwrap_or_default()
}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// A fraction of exponentials is rendered with parentheses and `Σ`.
    #[test]
    fn test_latex_to_readable() {
        assert_eq!(
            latex_to_readable("\\frac{\\exp(x_i)}{\\sum_j \\exp(x_j)}"),
            "(exp(x_i)) / (Σ_j exp(x_j))"
        );
    }

    /// End-to-end extraction of a softmax-like function.
    ///
    /// Uses a small fixture written to the system temp directory instead of a
    /// developer-specific PyTorch checkout, so the assertions actually run on
    /// every machine rather than being skipped when the checkout is absent.
    #[test]
    fn test_extract_softmax() {
        let source = [
            "def softmax(",
            "    input: Tensor,",
            "    dim: int,",
            ") -> Tensor:",
            "    r\"\"\"Apply a softmax function.",
            "",
            "    :math:`\\frac{\\exp(x_i)}{\\sum_j \\exp(x_j)}`",
            "    \"\"\"",
            "    return input",
            "",
        ]
        .join("\n");

        // The process id keeps concurrent test runs from sharing a file.
        let path = std::env::temp_dir().join(format!(
            "provable_contracts_extract_softmax_{}.py",
            std::process::id()
        ));
        std::fs::write(&path, source).expect("temp directory should be writable");

        let result = extract_from_pytorch(&format!("{}::softmax", path.display()));
        // Best-effort cleanup before asserting, so a failure leaves no file behind.
        let _ = std::fs::remove_file(&path);

        let kernel = result.unwrap();
        assert_eq!(kernel.function_name, "softmax");
        assert!(!kernel.equations.is_empty());
        assert!(kernel.equations[0].formula.contains("exp"));
    }
}
373
374#[cfg(test)]
375#[path = "extract_tests.rs"]
376mod extract_tests;