python_ast/parser/mod.rs

1use crate::{dump, Module, Name, SourceLocation, Error, Result as CrateResult, *};
2
3use pyo3::prelude::*;
4use std::ffi::CString;
5
6use std::path::MAIN_SEPARATOR;
7
8/// Takes a string of Python code and emits a Python struct that represents the AST.
9fn parse_to_py(
10    input: impl AsRef<str>,
11    filename: impl AsRef<str>,
12    py: Python<'_>,
13) -> PyResult<PyObject> {
14    let pymodule_code = include_str!("__init__.py");
15
16    // We want to call tokenize.tokenize from Python.
17    let code_cstr = CString::new(pymodule_code)?;
18    let pymodule = PyModule::from_code(py, &code_cstr, c"__init__.py", c"parser")?;
19    let t = pymodule.getattr("parse")?;
20    assert!(t.is_callable());
21    let args = (input.as_ref(), filename.as_ref());
22
23    let py_tree = t.call1(args)?;
24    log::debug!("py_tree: {}", dump(&py_tree, Some(4))?);
25
26    Ok(py_tree.into())
27}
28
29/// Parses Python code and returns the AST as a Module with improved error handling.
30/// 
31/// This function accepts any type that can be converted to a string reference,
32/// making it flexible for different input types. It provides detailed error information
33/// including file location and helpful guidance when parsing fails.
34/// 
35/// # Arguments
36/// * `input` - The Python source code to parse
37/// * `filename` - The filename to associate with the parsed code
38/// 
39/// # Returns
40/// * `CrateResult<Module>` - The parsed AST module or a detailed error with location info
41/// 
42/// # Examples
43/// ```rust
44/// use python_ast::parse_enhanced;
45/// 
46/// let code = "x = 1 + 2";
47/// let module = parse_enhanced(code, "example.py").unwrap();
48/// ```
49pub fn parse_enhanced(input: impl AsRef<str>, filename: impl AsRef<str>) -> CrateResult<Module> {
50    let filename = filename.as_ref();
51    let input_str = input.as_ref();
52    let location = SourceLocation::new(filename);
53    
54    // Empty files are valid in Python (they create empty modules), so we don't treat them as errors
55    
56    let mut module: Module = Python::with_gil(|py| {
57        let py_tree = parse_to_py(input_str, filename, py)
58            .map_err(|py_err| {
59                // Convert PyO3 errors to our more detailed error format
60                let error_msg = format!("Python parsing failed: {}", py_err);
61                let help_msg = if error_msg.contains("SyntaxError") {
62                    "Check your Python syntax. Common issues include missing colons, incorrect indentation, or unclosed brackets."
63                } else if error_msg.contains("IndentationError") {
64                    "Fix indentation issues. Python requires consistent indentation (use either spaces or tabs, not both)."
65                } else {
66                    "Ensure the input contains valid Python code. Check for syntax errors or unsupported constructs."
67                };
68                
69                Error::parsing_error(location.clone(), error_msg, help_msg)
70            })?;
71            
72        py_tree.extract(py)
73            .map_err(|py_err| {
74                Error::parsing_error(
75                    location.clone(),
76                    format!("Failed to extract AST: {}", py_err),
77                    "The Python code was parsed but could not be converted to our AST format. This may indicate unsupported Python features."
78                )
79            })
80    })?;
81    
82    module.filename = Some(filename.into());
83
84    if let Some(name_str) = filename.replace(MAIN_SEPARATOR, "__").strip_suffix(".py") {
85        module.name = Some(Name::try_from(name_str).map_err(|_| {
86            Error::parsing_error(
87                location,
88                format!("Invalid module name derived from filename: '{}'", name_str),
89                "Use a valid Python identifier for the filename (without special characters except underscores)."
90            )
91        })?);
92    }
93
94    log::debug!("module: {:#?}", module);
95    for item in module.__dir__() {
96        log::debug!("module.__dir__: {:#?}", item.as_ref());
97    }
98    Ok(module)
99}
100
101/// Parses Python code and returns the AST as a Module (backward compatible version).
102/// 
103/// This is the original parse function that returns PyResult for backward compatibility.
104/// For better error messages with location information, use `parse_enhanced` instead.
105/// 
106/// # Arguments
107/// * `input` - The Python source code to parse
108/// * `filename` - The filename to associate with the parsed code
109/// 
110/// # Returns
111/// * `PyResult<Module>` - The parsed AST module or a PyO3 error
112/// 
113/// # Examples
114/// ```rust
115/// use python_ast::parse;
116/// 
117/// let code = "x = 1 + 2";
118/// let module = parse(code, "example.py").unwrap();
119/// ```
120pub fn parse(input: impl AsRef<str>, filename: impl AsRef<str>) -> PyResult<Module> {
121    // Use the enhanced version but convert the error type for backward compatibility
122    parse_enhanced(input, filename).map_err(|e| e.into())
123}
124
#[cfg(test)]
mod tests {
    //! Tests for the parser entry points.
    //!
    //! Successful parses use `.expect(..)` rather than `assert!(is_ok())`
    //! followed by `unwrap()`, so a failing test prints the underlying error.

    use super::*;

    #[test]
    fn test_parse_simple_expression() {
        let code = "1 + 2";
        let module = parse(code, "test.py").expect("simple expression should parse");

        assert!(module.filename.is_some());
        assert_eq!(module.filename.as_ref().unwrap(), "test.py");
        assert!(!module.raw.body.is_empty());
    }

    #[test]
    fn test_parse_function_definition() {
        let code = r#"
def hello_world():
    return "Hello, World!"
"#;
        let module = parse(code, "function_test.py").expect("function definition should parse");
        assert_eq!(module.raw.body.len(), 1);
    }

    #[test]
    fn test_parse_class_definition() {
        let code = r#"
class TestClass:
    def __init__(self):
        self.value = 42
        
    def get_value(self):
        return self.value
"#;
        let module = parse(code, "class_test.py").expect("class definition should parse");
        assert_eq!(module.raw.body.len(), 1);
    }

    #[test]
    fn test_parse_import_statements() {
        let code = r#"
import os
import sys
from collections import defaultdict
from typing import List, Dict
"#;
        let module = parse(code, "import_test.py").expect("import statements should parse");
        assert_eq!(module.raw.body.len(), 4);
    }

    #[test]
    fn test_parse_complex_expressions() {
        let code = r#"
result = [x**2 for x in range(10) if x % 2 == 0]
data = {"key": value for key, value in items.items()}
condition = (a > b) and (c < d) or (e == f)
"#;
        let module = parse(code, "expressions_test.py").expect("complex expressions should parse");
        assert_eq!(module.raw.body.len(), 3);
    }

    #[test]
    fn test_parse_control_flow() {
        let code = r#"
if condition:
    for i in range(10):
        if i % 2 == 0:
            continue
        else:
            break
else:
    while True:
        try:
            do_something()
        except Exception as e:
            handle_error(e)
        finally:
            cleanup()
"#;
        let module = parse(code, "control_flow_test.py").expect("control flow should parse");
        assert_eq!(module.raw.body.len(), 1);
    }

    #[test]
    fn test_parse_async_code() {
        let code = r#"
async def async_function():
    async with async_context():
        result = await some_async_operation()
        async for item in async_iterator:
            yield item
"#;
        let module = parse(code, "async_test.py").expect("async code should parse");
        assert_eq!(module.raw.body.len(), 1);
    }

    #[test]
    fn test_parse_decorators() {
        let code = r#"
@decorator
@another_decorator(arg1, arg2)
def decorated_function():
    pass

@property
def getter(self):
    return self._value
"#;
        let module = parse(code, "decorators_test.py").expect("decorated functions should parse");
        assert_eq!(module.raw.body.len(), 2);
    }

    #[test]
    fn test_parse_invalid_syntax() {
        let code = "def invalid_function("; // Missing closing parenthesis
        let result = parse(code, "invalid.py");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_enhanced_invalid_syntax() {
        // The crate-error entry point must reject the same input as `parse`.
        let code = "def invalid_function("; // Missing closing parenthesis
        let result = parse_enhanced(code, "invalid.py");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_empty_file() {
        let code = "";
        let module = parse(code, "empty.py").expect("empty file should parse");
        assert!(module.raw.body.is_empty());
    }

    #[test]
    fn test_parse_comments_and_docstrings() {
        let code = r#"
"""Module docstring"""
# This is a comment
def function_with_docstring():
    """Function docstring"""
    pass  # Another comment
"#;
        let module = parse(code, "comments_test.py").expect("comments and docstrings should parse");
        assert_eq!(module.raw.body.len(), 2); // Docstring + function
    }

    #[test]
    fn test_parse_filename_without_py_suffix() {
        // No `.py` suffix means no module name is derived from the filename,
        // so the name check cannot fail the parse.
        let module = parse("x = 1", "no_extension").expect("non-.py filename should parse");
        assert!(module.filename.is_some());
    }

    #[test]
    fn test_module_name_generation() {
        let module = parse("x = 1", "some_file.py").expect("simple assignment should parse");
        assert!(module.name.is_some());
        assert_eq!(module.name.unwrap().id, "some_file");
    }

    #[test]
    fn test_module_name_with_path_separators() {
        let code = "x = 1";
        let filename = format!(
            "path{}to{}module.py",
            std::path::MAIN_SEPARATOR,
            std::path::MAIN_SEPARATOR
        );
        let module = parse(code, &filename).expect("nested path filename should parse");

        assert!(module.name.is_some());
        assert_eq!(module.name.unwrap().id, "path__to__module");
    }
}