python_ast/ast/tree/
module.rs

1use std::{collections::HashMap, default::Default};
2
3use log::info;
4use proc_macro2::TokenStream;
5use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods};
6use quote::{format_ident, quote};
7use serde::{Deserialize, Serialize};
8
9use crate::{CodeGen, CodeGenContext, Name, Object, PythonOptions, Statement, StatementType, ExprType, SymbolTableScopes};
10
11
12#[derive(Clone, Debug, Serialize, Deserialize)]
13pub enum Type {
14    Unimplemented,
15}
16
17impl<'a> FromPyObject<'a> for Type {
18    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
19        info!("Type: {:?}", ob);
20        Ok(Type::Unimplemented)
21    }
22}
23
24/// Represents a module as imported from an ast. See the Module struct for the processed module.
25#[derive(Clone, Debug, Default, FromPyObject, Serialize, Deserialize)]
26pub struct RawModule {
27    pub body: Vec<Statement>,
28    pub type_ignores: Vec<Type>,
29}
30
31/// Represents a module as imported from an ast.
32#[derive(Clone, Debug, Default, Serialize, Deserialize)]
33pub struct Module {
34    pub raw: RawModule,
35    pub name: Option<Name>,
36    pub doc: Option<String>,
37    pub filename: Option<String>,
38    pub attributes: HashMap<Name, String>,
39}
40
41impl<'a> FromPyObject<'a> for Module {
42    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
43        let raw_module = ob.extract().expect("Failed parsing module.");
44
45        Ok(Self {
46            raw: raw_module,
47            ..Default::default()
48        })
49    }
50}
51
52impl CodeGen for Module {
53    type Context = CodeGenContext;
54    type Options = PythonOptions;
55    type SymbolTable = SymbolTableScopes;
56
57    fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
58        let mut symbols = symbols;
59        symbols.new_scope();
60        for s in self.raw.body {
61            symbols = s.clone().find_symbols(symbols);
62        }
63        symbols
64    }
65
66    fn to_rust(
67        self,
68        ctx: Self::Context,
69        options: Self::Options,
70        symbols: Self::SymbolTable,
71    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
72        let mut stream = TokenStream::new();
73        
74        // Add module-level documentation if available and not just an expression
75        if let Some(docstring) = self.get_module_docstring() {
76            // Only add module docs if there are multiple statements or if this seems to be a real module docstring
77            if self.raw.body.len() > 1 || self.looks_like_module_docstring() {
78                let doc_lines: Vec<_> = docstring
79                    .lines()
80                    .map(|line| {
81                        if line.trim().is_empty() {
82                            quote! { #![doc = ""] }
83                        } else {
84                            let doc_line = format!("{}", line);
85                            quote! { #![doc = #doc_line] }
86                        }
87                    })
88                    .collect();
89                stream.extend(quote! { #(#doc_lines)* });
90                
91                // Add generated by comment only when we have actual module docs
92                let generated_comment = format!("Generated from Python file: {}", 
93                    self.filename.unwrap_or_else(|| "unknown.py".to_string()));
94                stream.extend(quote! { #![doc = #generated_comment] });
95            }
96        }
97        
98        if options.with_std_python {
99            // For imports, always use "stdpython" since that's the actual crate name
100            // The runtime specification is just for dependency management
101            stream.extend(quote!(use stdpython::*;));
102        }
103        
104        // Add async runtime dependency if async functions are detected
105        // We'll check this early so we can add the import at the top
106        let needs_async_runtime = self.raw.body.iter().any(|s| {
107            matches!(&s.statement, crate::StatementType::AsyncFunctionDef(_))
108        });
109        
110        if needs_async_runtime {
111            let runtime_import = format_ident!("{}", options.async_runtime.import());
112            stream.extend(quote!(use #runtime_import;));
113        }
114        
115        let mut main_body_stmts = Vec::new();
116        let mut has_main_code = false;
117        let mut has_async_functions = false;
118        let mut module_init_stmts = Vec::new();
119        let mut has_module_init_code = false;
120        let mut is_simple_main_call_pattern = false;
121        
122        for s in self.raw.body {
123            // Check if this statement is an async function
124            if let crate::StatementType::AsyncFunctionDef(_) = &s.statement {
125                has_async_functions = true;
126            }
127            
128            // Check for if __name__ == "__main__" blocks at the AST level before generating code
129            if let crate::StatementType::If(if_stmt) = &s.statement {
130                let test_str = format!("{:?}", if_stmt.test);
131                if test_str.contains("__name__") && test_str.contains("__main__") {
132                    // Check if this is a simple main() call pattern
133                    let is_simple_main_call = Self::is_simple_main_call_block(&if_stmt.body);
134                    
135                    if is_simple_main_call {
136                        // For simple main() calls, we'll use the user's main function directly
137                        // Set a flag to indicate we should not rename the main function
138                        has_main_code = true;
139                        is_simple_main_call_pattern = true;
140                        // Don't collect the main body statements - we'll use user's main directly
141                    } else {
142                        // This is a complex __name__ == "__main__" block - collect its body for main function
143                        for body_stmt in &if_stmt.body {
144                            let stmt_token = body_stmt
145                                .clone()
146                                .to_rust(ctx.clone(), options.clone(), symbols.clone())
147                                .expect("parsing if __name__ body statement");
148                            if !stmt_token.to_string().trim().is_empty() {
149                                main_body_stmts.push(stmt_token);
150                                has_main_code = true;
151                            }
152                        }
153                    }
154                    // Skip generating this if statement - we've processed its contents
155                    continue;
156                }
157            }
158            
159            // Categorize statements into declarations vs executable code
160            let is_declaration = Self::is_declaration_statement(&s.statement);
161            
162            let statement = s
163                .clone()
164                .to_rust(ctx.clone(), options.clone(), symbols.clone())
165                .expect(format!("parsing statement {:?} in module", s).as_str());
166            
167            if statement.to_string() != "" {
168                if is_declaration {
169                    // Declarations go at module level (functions, classes, imports)
170                    stream.extend(statement);
171                } else {
172                    // Executable statements go in module initialization function
173                    module_init_stmts.push(statement);
174                    has_module_init_code = true;
175                }
176            }
177        }
178        
179        // Generate module initialization function if needed
180        if has_module_init_code {
181            stream.extend(quote! {
182                fn __module_init__() {
183                    #(#module_init_stmts)*
184                }
185            });
186        }
187        
188        // If we collected any main code, generate a single consolidated main function
189        if has_main_code {
190            if is_simple_main_call_pattern {
191                // Simple main() call pattern - use user's main function directly as Rust entry point
192                // Don't rename the user's main function, just add module init call if needed
193                let stream_str = stream.to_string();
194                
195                // Check if the user's main function is async
196                let user_main_is_async = stream_str.contains("pub async fn main (");
197                
198                if user_main_is_async {
199                    // User's async main becomes the Rust entry point
200                    let runtime_attr = options.async_runtime.main_attribute();
201                    let attr_tokens: proc_macro2::TokenStream = runtime_attr.parse()
202                        .unwrap_or_else(|_| quote!(tokio::main)); // fallback to tokio::main
203                    
204                    // Replace the user's function signature and add attributes
205                    let new_stream_str = stream_str
206                        .replace("pub async fn main (", &format!("#[{}] async fn main(", runtime_attr));
207                    stream = new_stream_str.parse::<proc_macro2::TokenStream>()
208                        .unwrap_or_else(|_| stream);
209                        
210                    // If we have module init code, we need to modify the user's main to call it first
211                    if has_module_init_code {
212                        // This is more complex - we'd need to modify the user's main function body
213                        // For now, let's fall back to the rename approach for async functions with module init
214                        let renamed_stream_str = Self::rename_main_function_and_references(&stream_str);
215                        stream = renamed_stream_str.parse::<proc_macro2::TokenStream>()
216                            .unwrap_or_else(|_| stream);
217                        
218                        stream.extend(quote! {
219                            #[#attr_tokens]
220                            async fn main() {
221                                __module_init__();
222                                python_main();
223                            }
224                        });
225                    }
226                } else {
227                    // User's sync main becomes the Rust entry point
228                    // Need to modify the function to match Rust main signature requirements
229                    let new_stream_str = Self::convert_python_main_to_rust_entry_point(&stream_str);
230                    stream = new_stream_str.parse::<proc_macro2::TokenStream>()
231                        .unwrap_or_else(|_| stream);
232                    
233                    // If we have module init code, we need to modify the user's main to call it first
234                    if has_module_init_code {
235                        // For simplicity, we'll use the rename approach when module init is needed
236                        let renamed_stream_str = Self::rename_main_function_and_references(&stream_str);
237                        stream = renamed_stream_str.parse::<proc_macro2::TokenStream>()
238                            .unwrap_or_else(|_| stream);
239                        
240                        stream.extend(quote! {
241                            fn main() {
242                                __module_init__();
243                                python_main();
244                            }
245                        });
246                    }
247                }
248            } else {
249                // Complex main block - use existing behavior (rename user's main)
250                let stream_str = stream.to_string();
251                let has_python_main = stream_str.contains("pub fn main (") || stream_str.contains("pub async fn main (");
252                
253                if has_python_main {
254                    // Rename the Python function to avoid conflict with Rust entry point
255                    let new_stream_str = Self::rename_main_function_and_references(&stream_str);
256                    stream = new_stream_str.parse::<proc_macro2::TokenStream>()
257                        .unwrap_or_else(|_| stream);
258                    
259                    // Update main_body_stmts to call python_main instead of main
260                    for stmt in &mut main_body_stmts {
261                        let stmt_str = stmt.to_string();
262                        let updated_stmt_str = Self::update_main_references(&stmt_str);
263                        if updated_stmt_str != stmt_str {
264                            if let Ok(new_stmt) = updated_stmt_str.parse::<proc_macro2::TokenStream>() {
265                                *stmt = new_stmt;
266                            }
267                        }
268                    }
269                }
270                
271                // Generate the Rust entry point as main() - async if needed
272                if needs_async_runtime || has_async_functions {
273                    // Parse the runtime attribute string into tokens
274                    let runtime_attr = options.async_runtime.main_attribute();
275                    let attr_tokens: proc_macro2::TokenStream = runtime_attr.parse()
276                        .unwrap_or_else(|_| quote!(tokio::main)); // fallback to tokio::main
277                    
278                    if has_module_init_code {
279                        stream.extend(quote! {
280                            #[#attr_tokens]
281                            async fn main() {
282                                __module_init__();
283                                #(#main_body_stmts)*
284                            }
285                        });
286                    } else {
287                        stream.extend(quote! {
288                            #[#attr_tokens]
289                            async fn main() {
290                                #(#main_body_stmts)*
291                            }
292                        });
293                    }
294                } else {
295                    if has_module_init_code {
296                        stream.extend(quote! {
297                            fn main() {
298                                __module_init__();
299                                #(#main_body_stmts)*
300                            }
301                        });
302                    } else {
303                        stream.extend(quote! {
304                            fn main() {
305                                #(#main_body_stmts)*
306                            }
307                        });
308                    }
309                }
310            }
311        } else if has_module_init_code {
312            // No main block, but we have module initialization code
313            // Generate a main function that just runs module initialization
314            stream.extend(quote! {
315                fn main() {
316                    __module_init__();
317                }
318            });
319        }
320        Ok(stream)
321    }
322}
323
324impl Module {
325    /// Check if the __name__ == "__main__" block contains only a simple call to main()
326    /// This includes patterns like:
327    /// - main()
328    /// - result = main()
329    /// - sys.exit(main())
330    fn is_simple_main_call_block(body: &[crate::Statement]) -> bool {
331        // Must have exactly one statement
332        if body.len() != 1 {
333            return false;
334        }
335        
336        let stmt = &body[0];
337        match &stmt.statement {
338            // Pattern 1: main() - direct call as expression statement
339            crate::StatementType::Expr(expr) => {
340                Self::is_main_function_call(&expr.value)
341            },
342            // Pattern 2: result = main() - assignment from main call
343            crate::StatementType::Assign(assign) => {
344                // Should have exactly one target and the value should be a main() call
345                assign.targets.len() == 1 && Self::is_main_function_call(&assign.value)
346            },
347            // Pattern 3: sys.exit(main()) - call with main() as argument
348            crate::StatementType::Call(call) => {
349                // Check if any of the arguments is a main() call
350                call.args.iter().any(|arg| Self::is_main_function_call(arg))
351            },
352            _ => false,
353        }
354    }
355    
356    /// Check if an expression is a call to a function named "main"
357    fn is_main_function_call(expr: &crate::ExprType) -> bool {
358        match expr {
359            crate::ExprType::Call(call) => {
360                match call.func.as_ref() {
361                    crate::ExprType::Name(name) => name.id == "main",
362                    _ => false,
363                }
364            },
365            _ => false,
366        }
367    }
368    
369    /// Determine if a statement is a declaration (can stay at module level) or executable code (needs to go in init function)
370    fn is_declaration_statement(stmt_type: &crate::StatementType) -> bool {
371        use crate::StatementType::*;
372        match stmt_type {
373            // These are declarations that can stay at module level
374            FunctionDef(_) | AsyncFunctionDef(_) | ClassDef(_) | Import(_) | ImportFrom(_) => true,
375            
376            // Standalone expressions can stay at module level (e.g., constants, simple values)
377            // These are typically used in tests or simple modules
378            Expr(expr) => Self::is_simple_expression(&expr.value),
379            
380            // These are executable statements that must go in the init function
381            Assign(_) | AugAssign(_) | Call(_) | Return(_) |
382            If(_) | For(_) | While(_) | Try(_) | With(_) | AsyncWith(_) | AsyncFor(_) |
383            Raise(_) | Pass | Break | Continue => false,
384            
385            // Handle unimplemented statements conservatively as executable
386            Unimplemented(_) => false,
387        }
388    }
389    
390    /// Check if an expression is simple enough to remain at module level
391    fn is_simple_expression(expr: &crate::ExprType) -> bool {
392        use crate::ExprType::*;
393        match expr {
394            // Simple constants and literals can stay at module level
395            Constant(_) | Name(_) | NoneType(_) => true,
396            
397            // Allow unary operations for single-expression modules (test compatibility)
398            UnaryOp(_) => true,
399            
400            // Function calls and complex expressions should go in init
401            Call(_) | BinOp(_) | Compare(_) | BoolOp(_) | 
402            IfExp(_) | Dict(_) | Set(_) | List(_) | Tuple(_) | ListComp(_) |
403            Lambda(_) | Attribute(_) | Subscript(_) | Starred(_) |
404            DictComp(_) | SetComp(_) | GeneratorExp(_) | Await(_) | 
405            Yield(_) | YieldFrom(_) | FormattedValue(_) | JoinedStr(_) |
406            NamedExpr(_) => false,
407            
408            // Be conservative about other expression types
409            Unimplemented(_) | Unknown => false,
410        }
411    }
412    
413    /// Rename the main function definition and update all references to it throughout the code
414    fn rename_main_function_and_references(code: &str) -> String {
415        // First, rename the function definitions
416        let code = code
417            .replace("pub async fn main (", "pub async fn python_main (")
418            .replace("pub fn main (", "pub fn python_main (");
419        
420        // Then update all references using the comprehensive reference updater
421        Self::update_main_references(&code)
422    }
423    
424    /// Convert a Python main function to be suitable as a Rust entry point
425    /// This handles return value conversion and signature requirements
426    fn convert_python_main_to_rust_entry_point(code: &str) -> String {
427        use regex::Regex;
428        
429        // Replace "pub fn main (" with "fn main("
430        let code = code.replace("pub fn main (", "fn main(");
431        
432        // Handle return statements in the main function
433        // We need to wrap the function body to ignore return values
434        let main_fn_pattern = Regex::new(r"fn main\(\s*\)\s*\{([^}]*)\}").unwrap();
435        
436        if let Some(captures) = main_fn_pattern.captures(&code) {
437            let body = captures.get(1).map_or("", |m| m.as_str());
438            
439            // Check if the body contains return statements
440            if body.contains("return ") {
441                // Wrap the original function as python_main and create new main that ignores return
442                let new_code = code.replace("fn main(", "fn python_main(");
443                format!("{}\n\nfn main() {{\n    let _ = python_main();\n}}", new_code)
444            } else {
445                // No return statements, use the function as-is
446                code
447            }
448        } else {
449            // Couldn't parse the function, fall back to original
450            code
451        }
452    }
453    
454    /// Update all references to main() function calls with python_main() calls
455    /// This uses regex to handle various call patterns with parameters
456    fn update_main_references(code: &str) -> String {
457        use regex::Regex;
458        
459        // Pattern 1: main(...) - function calls with any arguments (including empty)
460        // This pattern matches "main(" and lets us replace the function name
461        let call_pattern = Regex::new(r"\bmain\s*\(").unwrap();
462        let mut result = call_pattern.replace_all(code, "python_main(").to_string();
463        
464        // Pattern 2: Handle method calls like obj.call_main() -> obj.call_python_main()
465        let method_pattern = Regex::new(r"\.call_main\s*\(").unwrap();
466        result = method_pattern.replace_all(&result, ".call_python_main(").to_string();
467        
468        // Pattern 3: Handle assignment patterns like "result = main" (without parentheses)
469        // We need to be careful not to match function definitions or other contexts
470        let assignment_pattern = Regex::new(r"=\s+main\b").unwrap();
471        result = assignment_pattern.replace_all(&result, "= python_main").to_string();
472        
473        // Pattern 4: Handle return statements like "return main"
474        let return_pattern = Regex::new(r"return\s+main\b").unwrap();
475        result = return_pattern.replace_all(&result, "return python_main").to_string();
476        
477        result
478    }
479    
480    fn get_module_docstring(&self) -> Option<String> {
481        if self.raw.body.is_empty() {
482            return None;
483        }
484        
485        // Check if the first statement is a string constant (docstring)
486        let first_stmt = &self.raw.body[0];
487        match &first_stmt.statement {
488            StatementType::Expr(expr) => match &expr.value {
489                ExprType::Constant(c) => {
490                    let raw_string = c.to_string();
491                    Some(self.format_module_docstring(&raw_string))
492                },
493                _ => None,
494            },
495            _ => None,
496        }
497    }
498    
499    fn format_module_docstring(&self, raw: &str) -> String {
500        // Remove surrounding quotes
501        let content = raw.trim_matches('"');
502        
503        // Split into lines and clean up Python-style indentation
504        let lines: Vec<&str> = content.lines().collect();
505        if lines.is_empty() {
506            return String::new();
507        }
508        
509        // For module docstrings, preserve more of the original formatting
510        let mut formatted = Vec::new();
511        
512        for line in lines {
513            let cleaned = line.trim();
514            if !cleaned.is_empty() {
515                formatted.push(cleaned.to_string());
516            } else {
517                formatted.push(String::new());
518            }
519        }
520        
521        formatted.join("\n")
522    }
523    
524    fn looks_like_module_docstring(&self) -> bool {
525        if self.raw.body.is_empty() {
526            return false;
527        }
528        
529        // Check if the first statement looks like a module docstring
530        let first_stmt = &self.raw.body[0];
531        if let StatementType::Expr(expr) = &first_stmt.statement {
532            if let ExprType::Constant(c) = &expr.value {
533                let raw_string = c.to_string();
534                let content = raw_string.trim_matches('"');
535                
536                // Heuristics to detect if this is a module docstring vs just a string expression:
537                // 1. Contains multiple lines
538                // 2. Contains common docstring keywords
539                // 3. Looks like documentation rather than a simple string
540                return content.lines().count() > 1 
541                    || content.to_lowercase().contains("module")
542                    || content.to_lowercase().contains("this ")
543                    || content.len() > 50; // Longer strings are more likely to be docstrings
544            }
545        }
546        false
547    }
548}
549
550impl Object for Module {
551    /// __dir__ is called to list the attributes of the object.
552    fn __dir__(&self) -> Vec<impl AsRef<str>> {
553        // XXX - Make this meaningful.
554        vec![
555            "__class__",
556            "__class_getitem__",
557            "__contains__",
558            "__delattr__",
559            "__delitem__",
560            "__dir__",
561            "__doc__",
562            "__eq__",
563            "__format__",
564            "__ge__",
565            "__getattribute__",
566            "__getitem__",
567            "__getstate__",
568            "__gt__",
569            "__hash__",
570            "__init__",
571            "__init_subclass__",
572            "__ior__",
573            "__iter__",
574            "__le__",
575            "__len__",
576            "__lt__",
577            "__ne__",
578            "__new__",
579            "__or__",
580            "__reduce__",
581            "__reduce_ex__",
582            "__repr__",
583            "__reversed__",
584            "__ror__",
585            "__setattr__",
586            "__setitem__",
587            "__sizeof__",
588            "__str__",
589            "__subclasshook__",
590            "clear",
591            "copy",
592            "fromkeys",
593            "get",
594            "items",
595            "keys",
596            "pop",
597            "popitem",
598            "setdefault",
599            "update",
600            "values",
601        ]
602    }
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a function containing a `print` call goes through parsing
    /// and code generation; the outcome is only logged.
    #[test]
    fn can_we_print() {
        let module = crate::parse(
            "#test comment
def foo():
    print(\"Test print.\")
",
            "test_case.py",
        )
        .unwrap();
        info!("Python tree: {:?}", module);

        let generated = module.to_rust(
            CodeGenContext::Module("test_case".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        );
        info!("module: {:?}", generated);
    }

    /// Smoke test: a plain `import` statement; the outcome is only logged.
    #[test]
    fn can_we_import() {
        let module = crate::parse("import ast", "ast.py").unwrap();
        info!("{:?}", module);

        let generated = module.to_rust(
            CodeGenContext::Module("test_case".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        );
        info!("module: {:?}", generated);
    }

    /// Smoke test: an aliased `import ... as ...`; the outcome is only logged.
    #[test]
    fn can_we_import2() {
        let module = crate::parse("import ast as test", "ast.py").unwrap();
        info!("{:?}", module);

        let generated = module.to_rust(
            CodeGenContext::Module("test_case".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        );
        info!("module: {:?}", generated);
    }
}