python_ast/parser/
mod.rs

1use crate::{dump, Module, Name, *};
2
3use pyo3::prelude::*;
4use std::ffi::CString;
5
6use std::path::MAIN_SEPARATOR;
7
8/// Takes a string of Python code and emits a Python struct that represents the AST.
9fn parse_to_py(
10    input: impl AsRef<str>,
11    filename: impl AsRef<str>,
12    py: Python<'_>,
13) -> PyResult<PyObject> {
14    let pymodule_code = include_str!("__init__.py");
15
16    // We want to call tokenize.tokenize from Python.
17    let code_cstr = CString::new(pymodule_code)?;
18    let pymodule = PyModule::from_code(py, &code_cstr, c"__init__.py", c"parser")?;
19    let t = pymodule.getattr("parse")?;
20    assert!(t.is_callable());
21    let args = (input.as_ref(), filename.as_ref());
22
23    let py_tree = t.call1(args)?;
24    log::debug!("py_tree: {}", dump(&py_tree, Some(4))?);
25
26    Ok(py_tree.into())
27}
28
29/// Parses Python code and returns the AST as a Module.
30/// 
31/// This function accepts any type that can be converted to a string reference,
32/// making it flexible for different input types.
33/// 
34/// # Arguments
35/// * `input` - The Python source code to parse
36/// * `filename` - The filename to associate with the parsed code
37/// 
38/// # Returns
39/// * `PyResult<Module>` - The parsed AST module or a Python error
40/// 
41/// # Examples
42/// ```rust
43/// use python_ast::parse;
44/// 
45/// let code = "x = 1 + 2";
46/// let module = parse(code, "example.py").unwrap();
47/// ```
48pub fn parse(input: impl AsRef<str>, filename: impl AsRef<str>) -> PyResult<Module> {
49    let filename = filename.as_ref();
50    let mut module: Module = Python::with_gil(|py| {
51        let py_tree = parse_to_py(input, filename, py)?;
52        py_tree.extract(py)
53    })?;
54    module.filename = Some(filename.into());
55
56    if let Some(name_str) = filename.replace(MAIN_SEPARATOR, "__").strip_suffix(".py") {
57        module.name =
58            Some(Name::try_from(name_str).unwrap_or_else(|_| panic!("Invalid name {}", name_str)));
59    }
60
61    println!("module: {:#?}", module);
62    for item in module.__dir__() {
63        println!("module.__dir__: {:#?}", item.as_ref());
64    }
65    Ok(module)
66}
67
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `code` as `filename`, panicking with the underlying Python error on failure.
    ///
    /// Unlike `assert!(result.is_ok())`, this keeps the `PyErr` in the failure
    /// message, so a broken test reports *why* the parse failed.
    fn parse_ok(code: &str, filename: &str) -> Module {
        parse(code, filename).unwrap_or_else(|err| panic!("parsing {filename} failed: {err:?}"))
    }

    #[test]
    fn test_parse_simple_expression() {
        let module = parse_ok("1 + 2", "test.py");
        assert!(module.filename.is_some());
        assert_eq!(module.filename.as_ref().unwrap(), "test.py");
        assert!(!module.raw.body.is_empty());
    }

    #[test]
    fn test_parse_function_definition() {
        let code = r#"
def hello_world():
    return "Hello, World!"
"#;
        let module = parse_ok(code, "function_test.py");
        assert_eq!(module.raw.body.len(), 1);
    }

    #[test]
    fn test_parse_class_definition() {
        let code = r#"
class TestClass:
    def __init__(self):
        self.value = 42

    def get_value(self):
        return self.value
"#;
        let module = parse_ok(code, "class_test.py");
        assert_eq!(module.raw.body.len(), 1);
    }

    #[test]
    fn test_parse_import_statements() {
        let code = r#"
import os
import sys
from collections import defaultdict
from typing import List, Dict
"#;
        let module = parse_ok(code, "import_test.py");
        assert_eq!(module.raw.body.len(), 4);
    }

    #[test]
    fn test_parse_complex_expressions() {
        let code = r#"
result = [x**2 for x in range(10) if x % 2 == 0]
data = {"key": value for key, value in items.items()}
condition = (a > b) and (c < d) or (e == f)
"#;
        let module = parse_ok(code, "expressions_test.py");
        assert_eq!(module.raw.body.len(), 3);
    }

    #[test]
    fn test_parse_control_flow() {
        let code = r#"
if condition:
    for i in range(10):
        if i % 2 == 0:
            continue
        else:
            break
else:
    while True:
        try:
            do_something()
        except Exception as e:
            handle_error(e)
        finally:
            cleanup()
"#;
        let module = parse_ok(code, "control_flow_test.py");
        assert_eq!(module.raw.body.len(), 1);
    }

    #[test]
    fn test_parse_async_code() {
        let code = r#"
async def async_function():
    async with async_context():
        result = await some_async_operation()
        async for item in async_iterator:
            yield item
"#;
        let module = parse_ok(code, "async_test.py");
        assert_eq!(module.raw.body.len(), 1);
    }

    #[test]
    fn test_parse_decorators() {
        let code = r#"
@decorator
@another_decorator(arg1, arg2)
def decorated_function():
    pass

@property
def getter(self):
    return self._value
"#;
        let module = parse_ok(code, "decorators_test.py");
        assert_eq!(module.raw.body.len(), 2);
    }

    #[test]
    fn test_parse_invalid_syntax() {
        let code = "def invalid_function("; // Missing closing parenthesis
        let result = parse(code, "invalid.py");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_empty_file() {
        let module = parse_ok("", "empty.py");
        assert!(module.raw.body.is_empty());
    }

    #[test]
    fn test_parse_comments_and_docstrings() {
        let code = r#"
"""Module docstring"""
# This is a comment
def function_with_docstring():
    """Function docstring"""
    pass  # Another comment
"#;
        let module = parse_ok(code, "comments_test.py");
        assert_eq!(module.raw.body.len(), 2); // Docstring + function
    }

    #[test]
    fn test_module_name_generation() {
        let module = parse_ok("x = 1", "some_file.py");
        assert!(module.name.is_some());
        assert_eq!(module.name.unwrap().id, "some_file");
    }

    #[test]
    fn test_module_name_with_path_separators() {
        let filename = format!(
            "path{sep}to{sep}module.py",
            sep = std::path::MAIN_SEPARATOR
        );
        let module = parse_ok("x = 1", &filename);
        assert!(module.name.is_some());
        assert_eq!(module.name.unwrap().id, "path__to__module");
    }
}