1use std::path::Path;
7
/// One equation recovered from a Python docstring, plus the contract
/// conditions that were inferred for it.
#[derive(Debug, Clone)]
pub struct ExtractedEquation {
    /// Identifier of the equation; the extractor in this file sets it to the
    /// Python function name, and the YAML writer uses it as the mapping key.
    pub name: String,
    /// Human-readable rendering of the formula (LaTeX markup already simplified).
    pub formula: String,
    /// Precondition expressions, stored as plain strings.
    pub preconditions: Vec<String>,
    /// Postcondition expressions, stored as plain strings.
    pub postconditions: Vec<String>,
    /// Path of the file the equation was extracted from.
    pub source_file: String,
    /// Line associated with the equation. The extractor in this file passes the
    /// zero-based index of the `def` line.
    pub source_line: usize,
}
18
/// Everything extracted for a single Python function: its signature pieces,
/// its docstring, and the equations found in that docstring.
#[derive(Debug, Clone)]
pub struct ExtractedKernel {
    /// Name of the Python function.
    pub function_name: String,
    /// Where the function came from; the extractor in this file stores the
    /// source file path here.
    pub module_path: String,
    /// Text of the function's docstring (empty when none was found).
    pub docstring: String,
    /// Equations extracted from the docstring.
    pub equations: Vec<ExtractedEquation>,
    /// `(name, annotation)` pairs for the annotated parameters.
    pub arguments: Vec<(String, String)>,
    /// Return annotation text as written in the Python signature.
    pub return_type: String,
}
29
30pub fn extract_from_pytorch(target: &str) -> Result<ExtractedKernel, String> {
35 let (file_path, fn_name) = if target.contains("::") {
36 let parts: Vec<&str> = target.splitn(2, "::").collect();
37 (parts[0], Some(parts[1]))
38 } else {
39 (target, None)
40 };
41
42 let content = std::fs::read_to_string(file_path)
43 .map_err(|e| format!("Failed to read {file_path}: {e}"))?;
44
45 let fn_name = fn_name.unwrap_or_else(|| {
46 Path::new(file_path)
48 .file_stem()
49 .and_then(|s| s.to_str())
50 .unwrap_or("unknown")
51 });
52
53 extract_function(&content, fn_name, file_path)
54}
55
56fn extract_function(
57 content: &str,
58 fn_name: &str,
59 file_path: &str,
60) -> Result<ExtractedKernel, String> {
61 let lines: Vec<&str> = content.lines().collect();
62
63 let def_pattern = format!("def {fn_name}(");
65 let def_line = lines
66 .iter()
67 .enumerate()
68 .find(|(_, line)| line.trim().starts_with(&def_pattern))
69 .map(|(i, _)| i)
70 .ok_or_else(|| format!("Function `{fn_name}` not found in {file_path}"))?;
71
72 let args = extract_arguments(&lines, def_line);
74
75 let docstring = extract_docstring(&lines, def_line);
77
78 let equations = extract_equations_from_docstring(&docstring, fn_name, file_path, def_line);
80
81 let return_type = extract_return_type(&lines, def_line);
83
84 Ok(ExtractedKernel {
85 function_name: fn_name.to_string(),
86 module_path: file_path.to_string(),
87 docstring,
88 equations,
89 arguments: args,
90 return_type,
91 })
92}
93
/// Extracts the annotated parameters of the Python function defined at `def_line`.
///
/// The parameter list is taken to be the text between the first `(` on or after
/// `def_line` and its matching `)`, which may span several lines. It is split
/// on commas that are not nested inside brackets, so `Tuple[int, int]` stays
/// in one piece. `#` comments are ignored.
///
/// Returns `(name, annotation)` pairs in declaration order. The annotation is
/// everything after the first `:` of the parameter, so a default value
/// (`Optional[int] = None`) is kept as part of it. Parameters without an
/// annotation, `self`, and names starting with `_` are skipped.
///
/// Limitation: string literals are not tokenised, so brackets, commas or `#`
/// inside a string default can confuse the scan.
fn extract_arguments(lines: &[&str], def_line: usize) -> Vec<(String, String)> {
    let mut params: Vec<String> = Vec::new();
    let mut current = String::new();
    // 0 = before the opening paren, 1 = directly inside the parameter list,
    // >1 = inside a nested bracket such as `Optional[...]`.
    let mut depth = 0usize;

    'signature: for line in lines.iter().skip(def_line) {
        // Drop trailing comments so they are never parsed as parameters.
        let code = line.split('#').next().unwrap_or("");
        for ch in code.chars() {
            match ch {
                '(' | '[' | '{' => {
                    depth += 1;
                    if depth > 1 {
                        current.push(ch);
                    }
                }
                ')' | ']' | '}' if depth > 0 => {
                    depth -= 1;
                    if depth == 0 {
                        // Matching close of the parameter list: stop before
                        // the return annotation and the function body.
                        break 'signature;
                    }
                    current.push(ch);
                }
                ',' if depth == 1 => params.push(std::mem::take(&mut current)),
                _ if depth > 0 => current.push(ch),
                _ => {}
            }
        }
        if depth > 0 {
            // Keep tokens from adjacent lines apart.
            current.push(' ');
        }
    }
    params.push(current);

    params
        .iter()
        .filter_map(|param| parse_annotated_parameter(param))
        .collect()
}

/// Splits one `name: annotation` parameter; returns `None` for parameters that
/// have no annotation or that the extractor deliberately skips.
fn parse_annotated_parameter(param: &str) -> Option<(String, String)> {
    let (name, annotation) = param.split_once(':')?;
    let name = name.trim();

    // The text before the colon must be a bare identifier (optionally starred);
    // otherwise the colon belongs to a default value, not to an annotation.
    let identifier = name.trim_start_matches('*');
    let is_identifier = !identifier.is_empty()
        && identifier.chars().all(|c| c.is_alphanumeric() || c == '_');

    if !is_identifier || name == "self" || name.starts_with('_') {
        return None;
    }
    Some((name.to_string(), annotation.trim().to_string()))
}
126
/// Returns the docstring of the Python function defined at `def_line`, or an
/// empty string when the function has none.
///
/// The signature (which may span several lines) is skipped by balancing
/// brackets until a line ends with `:`. The docstring must then be the first
/// statement of the body; blank lines and `#` comment lines before it are
/// tolerated. Both `"""` and `'''` delimiters are accepted, optionally
/// prefixed with `r`/`R`.
///
/// Every docstring line is trimmed. Lines are joined with `\n`, and the text
/// preceding the closing delimiter is appended without a trailing newline. An
/// unterminated docstring runs to the end of `lines`.
///
/// Limitation: brackets or `#` inside string defaults in the signature are not
/// recognised as string content.
fn extract_docstring(lines: &[&str], def_line: usize) -> String {
    // Find the first body line by locating the end of the signature.
    let mut depth = 0usize;
    let mut body_start = None;
    for (index, line) in lines.iter().enumerate().skip(def_line) {
        let code = line.split('#').next().unwrap_or("").trim_end();
        for ch in code.chars() {
            match ch {
                '(' | '[' | '{' => depth += 1,
                ')' | ']' | '}' => depth = depth.saturating_sub(1),
                _ => {}
            }
        }
        if depth == 0 {
            // Brackets are balanced: the signature ends here. If code follows
            // the colon (`def f(x): return x`) there is no room for a docstring.
            if code.ends_with(':') {
                body_start = Some(index + 1);
            }
            break;
        }
    }
    let body_start = match body_start {
        Some(index) => index,
        None => return String::new(),
    };

    let mut body = lines
        .iter()
        .skip(body_start)
        .map(|line| line.trim())
        .skip_while(|line| line.is_empty() || line.starts_with('#'));

    // Only the first statement of the body can be the docstring; anything else
    // means the function has none (and we must not wander into later code).
    let first = match body.next() {
        Some(line) => line,
        None => return String::new(),
    };
    let unprefixed = first
        .strip_prefix('r')
        .or_else(|| first.strip_prefix('R'))
        .unwrap_or(first);
    let (delimiter, opening) = if let Some(rest) = unprefixed.strip_prefix("\"\"\"") {
        ("\"\"\"", rest)
    } else if let Some(rest) = unprefixed.strip_prefix("'''") {
        ("'''", rest)
    } else {
        return String::new();
    };

    // One-line docstring: opener and closer share a line.
    if let Some(end) = opening.find(delimiter) {
        return opening[..end].to_string();
    }

    let mut doc = String::from(opening);
    doc.push('\n');
    for line in body {
        if let Some(end) = line.find(delimiter) {
            doc.push_str(&line[..end]);
            break;
        }
        doc.push_str(line);
        doc.push('\n');
    }
    doc
}
158
159fn extract_equations_from_docstring(
160 docstring: &str,
161 fn_name: &str,
162 file_path: &str,
163 line: usize,
164) -> Vec<ExtractedEquation> {
165 let mut equations = Vec::new();
166
167 let mut pos = 0;
169 while let Some(start) = docstring[pos..].find(":math:`") {
170 let abs_start = pos + start + 7; if let Some(end) = docstring[abs_start..].find('`') {
172 let formula = &docstring[abs_start..abs_start + end];
173
174 let readable = latex_to_readable(formula);
176
177 let preconditions = infer_preconditions(docstring, fn_name);
179
180 let postconditions = infer_postconditions(docstring, fn_name);
182
183 equations.push(ExtractedEquation {
184 name: fn_name.to_string(),
185 formula: readable,
186 preconditions,
187 postconditions,
188 source_file: file_path.to_string(),
189 source_line: line,
190 });
191
192 pos = abs_start + end + 1;
193 } else {
194 break;
195 }
196 }
197
198 if equations.is_empty() {
199 equations.push(ExtractedEquation {
201 name: fn_name.to_string(),
202 formula: format!("{fn_name}(input) → output"),
203 preconditions: vec!["!input.is_empty()".to_string()],
204 postconditions: vec!["ret.iter().all(|x| x.is_finite())".to_string()],
205 source_file: file_path.to_string(),
206 source_line: line,
207 });
208 }
209
210 equations
211}
212
/// Returns the return annotation of the Python function defined at `def_line`,
/// or `"Tensor"` when the signature has none.
///
/// The signature is scanned from `def_line` by balancing brackets. Once the
/// parameter list has closed, the text up to the terminating top-level `:` is
/// the annotation. The signature may span any number of lines, `#` comments are
/// ignored, and whitespace inside a multi-line annotation is collapsed to
/// single spaces. Text beyond the signature (docstring, body) is never
/// inspected.
///
/// Limitation: brackets or `#` inside string defaults are not recognised as
/// string content.
fn extract_return_type(lines: &[&str], def_line: usize) -> String {
    const DEFAULT_RETURN: &str = "Tensor";

    let mut depth = 0usize;
    let mut params_closed = false;
    let mut terminated = false;
    let mut annotation = String::new();

    'signature: for line in lines.iter().skip(def_line) {
        let code = line.split('#').next().unwrap_or("");
        for ch in code.chars() {
            if !params_closed {
                // Still inside (or before) the parameter list: only track nesting.
                match ch {
                    '(' | '[' | '{' => depth += 1,
                    ')' | ']' | '}' if depth > 0 => {
                        depth -= 1;
                        params_closed = depth == 0;
                    }
                    _ => {}
                }
                continue;
            }
            match ch {
                // The first top-level colon after the parameter list ends the signature.
                ':' if depth == 0 => {
                    terminated = true;
                    break 'signature;
                }
                '(' | '[' | '{' => depth += 1,
                ')' | ']' | '}' => depth = depth.saturating_sub(1),
                _ => {}
            }
            annotation.push(ch);
        }
        annotation.push(' ');
    }

    if !terminated {
        return DEFAULT_RETURN.to_string();
    }

    let normalized = annotation.split_whitespace().collect::<Vec<_>>().join(" ");
    match normalized.strip_prefix("->").map(str::trim) {
        Some(ret) if !ret.is_empty() => ret.to_string(),
        _ => DEFAULT_RETURN.to_string(),
    }
}
222
/// Converts a small subset of LaTeX math into a plain-text formula.
///
/// Handled constructs:
/// - `\text{...}` is unwrapped to its content (both braces are removed);
/// - `\frac{a}{b}` becomes `(a) / (b)`;
/// - `\exp`, `\log`, `\max` lose their backslash; `\sum`, `\sqrt`, `\sigma`,
///   `\mu`, `\epsilon` become `Σ`, `√`, `σ`, `μ`, `ε`;
/// - any remaining braces become parentheses.
///
/// This is a textual heuristic, not a LaTeX parser: commands it does not know
/// are left untouched.
fn latex_to_readable(latex: &str) -> String {
    unwrap_text_commands(latex)
        // Fallback for a `\text{` that could not be unwrapped (e.g. nested).
        .replace("\\text{", "")
        .replace("\\frac{", "(")
        .replace("}{", ") / (")
        .replace("\\exp", "exp")
        .replace("\\sum", "Σ")
        .replace("\\log", "log")
        .replace("\\max", "max")
        .replace("\\sqrt", "√")
        .replace("\\sigma", "σ")
        .replace("\\mu", "μ")
        .replace("\\epsilon", "ε")
        .replace('}', ")")
        .replace('{', "(")
        .replace("_((", "_(")
}

/// Replaces every `\text{content}` with `content`, removing the matching
/// closing brace as well so it cannot later turn into a stray `)`.
fn unwrap_text_commands(latex: &str) -> String {
    const MARKER: &str = "\\text{";

    let mut out = String::with_capacity(latex.len());
    let mut rest = latex;
    while let Some(start) = rest.find(MARKER) {
        out.push_str(&rest[..start]);
        let inner = &rest[start + MARKER.len()..];

        // Locate the brace that closes this `\text{`, honouring nested braces.
        let mut depth = 1usize;
        let close = inner.char_indices().find_map(|(index, ch)| {
            match ch {
                '{' => depth += 1,
                '}' => {
                    depth -= 1;
                    if depth == 0 {
                        return Some(index);
                    }
                }
                _ => {}
            }
            None
        });

        match close {
            Some(end) => {
                out.push_str(&inner[..end]);
                rest = &inner[end + 1..];
            }
            None => {
                // Unbalanced input: just drop the marker.
                out.push_str(inner);
                rest = "";
            }
        }
    }
    out.push_str(rest);
    out
}
240
/// Derives precondition expressions from keywords in `docstring`.
///
/// The non-empty-input precondition is always present. A `dim` bound is added
/// when the docstring mentions `dim`, and a strict-positivity requirement when
/// it mentions `positive` or `> 0`. `_fn_name` is currently unused.
fn infer_preconditions(docstring: &str, _fn_name: &str) -> Vec<String> {
    let mentions_positive = docstring.contains("positive") || docstring.contains("> 0");

    // (keyword matched?, precondition to emit), in output order.
    let keyword_rules = [
        (docstring.contains("dim"), "dim < input.ndim()"),
        (mentions_positive, "input.iter().all(|x| *x > 0.0)"),
    ];

    std::iter::once("!input.is_empty()")
        .chain(
            keyword_rules
                .iter()
                .filter(|(applies, _)| *applies)
                .map(|(_, rule)| *rule),
        )
        .map(String::from)
        .collect()
}
253
/// Derives postcondition expressions from keywords in `docstring`.
///
/// Recognised hints are a `[0, 1]` range, "sum to 1"/"sum to one", and
/// "normalized"/"unit" (which yields the finiteness check). When nothing
/// matches, the finiteness check alone is returned, so the result is never
/// empty. `_fn_name` is currently unused.
fn infer_postconditions(docstring: &str, _fn_name: &str) -> Vec<String> {
    const ALL_FINITE: &str = "ret.iter().all(|x| x.is_finite())";

    let mentions = |needle: &str| docstring.contains(needle);

    // (keyword matched?, postcondition to emit), in output order.
    let keyword_rules = [
        (
            mentions("[0, 1]") || mentions("range `[0, 1]`"),
            "ret.iter().all(|&v| v >= 0.0 && v <= 1.0)",
        ),
        (
            mentions("sum to 1") || mentions("sum to one"),
            "(ret.iter().sum::<f32>() - 1.0).abs() < 1e-6",
        ),
        (mentions("normalized") || mentions("unit"), ALL_FINITE),
    ];

    let mut posts: Vec<String> = keyword_rules
        .iter()
        .filter(|rule| rule.0)
        .map(|rule| rule.1.to_string())
        .collect();

    if posts.is_empty() {
        posts.push(ALL_FINITE.to_string());
    }
    posts
}
273
274pub fn kernel_to_yaml(kernel: &ExtractedKernel) -> String {
276 let mut yaml = String::new();
277
278 yaml.push_str(&format!("# Auto-extracted from {}\n", kernel.module_path));
279 yaml.push_str(&format!("# Function: {}\n\n", kernel.function_name));
280
281 yaml.push_str("metadata:\n");
282 yaml.push_str(" version: \"1.0.0\"\n");
283 yaml.push_str(" created: \"2026-03-21\"\n");
284 yaml.push_str(" author: \"pv extract-pytorch\"\n");
285 yaml.push_str(&format!(
286 " description: \"Contract for {} extracted from PyTorch\"\n",
287 kernel.function_name
288 ));
289 yaml.push_str(" references:\n");
290 yaml.push_str(&format!(" - \"{}\"\n\n", kernel.module_path));
291
292 yaml.push_str("equations:\n");
293 for eq in &kernel.equations {
294 yaml.push_str(&format!(" {}:\n", eq.name));
295 yaml.push_str(&format!(
296 " formula: \"{}\"\n",
297 eq.formula.replace('"', "'")
298 ));
299 if !eq.preconditions.is_empty() {
300 yaml.push_str(" preconditions:\n");
301 for pre in &eq.preconditions {
302 yaml.push_str(&format!(" - \"{pre}\"\n"));
303 }
304 }
305 if !eq.postconditions.is_empty() {
306 yaml.push_str(" postconditions:\n");
307 for post in &eq.postconditions {
308 yaml.push_str(&format!(" - \"{post}\"\n"));
309 }
310 }
311 yaml.push_str(&format!(
312 " lean_theorem: \"ProvableContracts.Theorems.{}.Correctness\"\n\n",
313 capitalize(&eq.name)
314 ));
315 }
316
317 yaml.push_str("falsification_tests:\n");
318 yaml.push_str(&format!(
319 " - id: FALSIFY-{}-001\n",
320 kernel.function_name.to_uppercase()
321 ));
322 yaml.push_str(&format!(
323 " rule: \"{} correctness\"\n",
324 kernel.function_name
325 ));
326 yaml.push_str(&format!(
327 " test: \"test_{}_basic\"\n",
328 kernel.function_name
329 ));
330 yaml.push_str(&format!(
331 " prediction: \"{} output matches PyTorch reference\"\n",
332 kernel.function_name
333 ));
334 yaml.push_str(&format!(
335 " if_fails: \"{} implementation diverges from PyTorch\"\n",
336 kernel.function_name
337 ));
338
339 yaml
340}
341
/// Returns `s` with its first character upper-cased and the rest unchanged.
///
/// Upper-casing follows Unicode rules, so a single character may expand to
/// several (for example `ß` becomes `SS`). An empty input yields an empty string.
fn capitalize(s: &str) -> String {
    let mut rest = s.chars();
    rest.next().map_or_else(String::new, |first| {
        first.to_uppercase().chain(rest).collect()
    })
}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// A softmax-style fraction is rendered as a parenthesised quotient.
    #[test]
    fn test_latex_to_readable() {
        let rendered = latex_to_readable("\\frac{\\exp(x_i)}{\\sum_j \\exp(x_j)}");
        assert_eq!(rendered, "(exp(x_i)) / (Σ_j exp(x_j))");
    }

    /// End-to-end extraction of `softmax` from a local PyTorch checkout.
    ///
    /// NOTE(review): the checkout path is machine-specific; on any machine
    /// where it does not exist the test returns early and passes without
    /// checking anything.
    #[test]
    fn test_extract_softmax() {
        let pytorch_path = "/home/noah/src/pytorch/torch/nn/functional.py";
        if !std::path::Path::new(pytorch_path).exists() {
            return;
        }

        let target = format!("{pytorch_path}::softmax");
        let kernel = extract_from_pytorch(&target).unwrap();

        assert_eq!(kernel.function_name, "softmax");
        assert!(!kernel.equations.is_empty());
        assert!(kernel.equations[0].formula.contains("exp"));
    }
}
373
// Additional tests for this module are kept in the separate file
// `extract_tests.rs` and are only compiled for test builds.
#[cfg(test)]
#[path = "extract_tests.rs"]
mod extract_tests;