1use anyhow::{Context, Result};
6use pyo3::prelude::*;
7use pyo3::types::PyModule;
8use serde::{Deserialize, Serialize};
9
/// Owned, serializable snapshot of a single node taken from a Python `ast` tree.
///
/// Instances are built by walking the Python objects returned by `ast.parse`
/// and copying the parts this module understands into plain Rust data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PythonAST {
    /// Class name of the Python object the node was built from, read from
    /// `__class__.__name__` (for example `"Module"` or `"FunctionDef"`).
    pub node_type: String,
    /// Value of the Python object's `lineno` attribute, or `None` when the
    /// attribute is missing or cannot be converted to `usize`.
    pub lineno: Option<usize>,
    /// Value of the Python object's `col_offset` attribute, or `None` when the
    /// attribute is missing or cannot be converted to `usize`.
    pub col_offset: Option<usize>,
    /// Child nodes extracted for this node; which Python attributes feed this
    /// list depends on `node_type`.
    pub children: Vec<PythonAST>,
    /// String-valued attributes copied from the Python node, keyed by the
    /// attribute name (this file stores `"name"` for `FunctionDef` and `"id"`
    /// for `Name`).
    pub attributes: std::collections::HashMap<String, String>,
}
24
25impl PythonAST {
26 #[must_use]
28 pub fn new(node_type: String) -> Self {
29 Self {
30 node_type,
31 lineno: None,
32 col_offset: None,
33 children: Vec::new(),
34 attributes: std::collections::HashMap::new(),
35 }
36 }
37}
38
/// Parses Python `source` into a [`PythonAST`] tree.
///
/// Acquires the Python GIL and delegates to `parse_with_python`, which calls
/// Python's `ast.parse(source, filename)` and converts the result.
///
/// # Errors
///
/// Returns an error if the Python `ast` module cannot be imported, if
/// `ast.parse` raises (for example on invalid syntax), or if converting the
/// resulting Python objects into [`PythonAST`] nodes fails.
pub fn parse(source: &str, filename: &str) -> Result<PythonAST> {
    Python::with_gil(|py| parse_with_python(py, source, filename))
}
47
48fn parse_with_python(py: Python<'_>, source: &str, filename: &str) -> Result<PythonAST> {
50 let ast_module =
52 PyModule::import_bound(py, "ast").context("Failed to import Python ast module")?;
53
54 let ast_obj = ast_module
56 .call_method1("parse", (source, filename))
57 .context("Failed to parse Python source code")?;
58
59 extract_ast_node(&ast_obj)
61}
62
63fn extract_ast_node(obj: &Bound<'_, PyAny>) -> Result<PythonAST> {
65 let node_type = obj
66 .getattr("__class__")?
67 .getattr("__name__")?
68 .extract::<String>()?;
69
70 let mut ast = PythonAST::new(node_type.clone());
71
72 extract_location_info(obj, &mut ast);
74
75 extract_node_attributes(obj, &node_type, &mut ast)?;
77
78 Ok(ast)
79}
80
81fn extract_location_info(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) {
83 if let Ok(lineno) = obj.getattr("lineno") {
84 ast.lineno = lineno.extract().ok();
85 }
86 if let Ok(col_offset) = obj.getattr("col_offset") {
87 ast.col_offset = col_offset.extract().ok();
88 }
89}
90
91fn extract_node_attributes(
93 obj: &Bound<'_, PyAny>,
94 node_type: &str,
95 ast: &mut PythonAST,
96) -> Result<()> {
97 match node_type {
98 "Module" => extract_module_attrs(obj, ast)?,
99 "FunctionDef" => extract_function_def_attrs(obj, ast)?,
100 "Return" => extract_return_attrs(obj, ast)?,
101 "Call" => extract_call_attrs(obj, ast)?,
102 "Name" => extract_name_attrs(obj, ast)?,
103 _ => extract_default_attrs(obj, ast)?,
104 }
105 Ok(())
106}
107
108fn extract_module_attrs(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) -> Result<()> {
110 if let Ok(body) = obj.getattr("body") {
111 ast.children = extract_list(&body)?;
112 }
113 Ok(())
114}
115
116fn extract_function_def_attrs(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) -> Result<()> {
118 if let Ok(name) = obj.getattr("name") {
119 ast.attributes.insert("name".to_string(), name.extract()?);
120 }
121 if let Ok(body) = obj.getattr("body") {
122 ast.children = extract_list(&body)?;
123 }
124 Ok(())
125}
126
127fn extract_return_attrs(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) -> Result<()> {
129 if let Ok(value) = obj.getattr("value") {
130 if !value.is_none() {
131 ast.children.push(extract_ast_node(&value)?);
132 }
133 }
134 Ok(())
135}
136
137fn extract_call_attrs(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) -> Result<()> {
139 if let Ok(func) = obj.getattr("func") {
140 ast.children.push(extract_ast_node(&func)?);
141 }
142 if let Ok(args) = obj.getattr("args") {
143 ast.children.extend(extract_list(&args)?);
144 }
145 Ok(())
146}
147
148fn extract_name_attrs(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) -> Result<()> {
150 if let Ok(id) = obj.getattr("id") {
151 ast.attributes.insert("id".to_string(), id.extract()?);
152 }
153 Ok(())
154}
155
156#[allow(clippy::unnecessary_wraps)]
158fn extract_default_attrs(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) -> Result<()> {
159 if let Ok(value) = obj.getattr("value") {
160 if !value.is_none() {
161 if let Ok(child) = extract_ast_node(&value) {
162 ast.children.push(child);
163 }
164 }
165 }
166 Ok(())
167}
168
169fn extract_list(list: &Bound<'_, PyAny>) -> Result<Vec<PythonAST>> {
171 let mut nodes = Vec::new();
172 for item in list.iter()? {
173 let item = item?;
174 nodes.push(extract_ast_node(&item)?);
175 }
176 Ok(nodes)
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain function definition parses into a `Module` with children.
    #[test]
    fn test_parse_simple_function() {
        let code = r"
def my_len(x):
 return len(x)
";
        let tree = parse(code, "test.py").unwrap();
        assert_eq!(tree.node_type, "Module");
        assert!(!tree.children.is_empty());
    }

    /// Type annotations on parameters and return value do not break parsing.
    #[test]
    fn test_parse_with_type_hints() {
        let code = r"
def my_len(x: list) -> int:
 return len(x)
";
        let tree = parse(code, "test.py").unwrap();
        assert_eq!(tree.node_type, "Module");
    }

    /// Source that Python rejects surfaces as an `Err`.
    #[test]
    fn test_parse_invalid_syntax() {
        let outcome = parse("def invalid syntax here", "test.py");
        assert!(outcome.is_err());
    }
}