1use std::{collections::HashMap, default::Default};
2
3use log::info;
4use proc_macro2::TokenStream;
5use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods};
6use quote::{format_ident, quote};
7use serde::{Deserialize, Serialize};
8
9use crate::{CodeGen, CodeGenContext, Name, Object, PythonOptions, Statement, StatementType, ExprType, SymbolTableScopes};
10
11
12#[derive(Clone, Debug, Serialize, Deserialize)]
13pub enum Type {
14 Unimplemented,
15}
16
17impl<'a> FromPyObject<'a> for Type {
18 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
19 info!("Type: {:?}", ob);
20 Ok(Type::Unimplemented)
21 }
22}
23
24#[derive(Clone, Debug, Default, FromPyObject, Serialize, Deserialize)]
26pub struct RawModule {
27 pub body: Vec<Statement>,
28 pub type_ignores: Vec<Type>,
29}
30
31#[derive(Clone, Debug, Default, Serialize, Deserialize)]
33pub struct Module {
34 pub raw: RawModule,
35 pub name: Option<Name>,
36 pub doc: Option<String>,
37 pub filename: Option<String>,
38 pub attributes: HashMap<Name, String>,
39}
40
41impl<'a> FromPyObject<'a> for Module {
42 fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
43 let raw_module = ob.extract().expect("Failed parsing module.");
44
45 Ok(Self {
46 raw: raw_module,
47 ..Default::default()
48 })
49 }
50}
51
52impl CodeGen for Module {
53 type Context = CodeGenContext;
54 type Options = PythonOptions;
55 type SymbolTable = SymbolTableScopes;
56
57 fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
58 let mut symbols = symbols;
59 symbols.new_scope();
60 for s in self.raw.body {
61 symbols = s.clone().find_symbols(symbols);
62 }
63 symbols
64 }
65
66 fn to_rust(
67 self,
68 ctx: Self::Context,
69 options: Self::Options,
70 symbols: Self::SymbolTable,
71 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
72 let mut stream = TokenStream::new();
73
74 if let Some(docstring) = self.get_module_docstring() {
76 if self.raw.body.len() > 1 || self.looks_like_module_docstring() {
78 let doc_lines: Vec<_> = docstring
79 .lines()
80 .map(|line| {
81 if line.trim().is_empty() {
82 quote! { #![doc = ""] }
83 } else {
84 let doc_line = format!("{}", line);
85 quote! { #![doc = #doc_line] }
86 }
87 })
88 .collect();
89 stream.extend(quote! { #(#doc_lines)* });
90
91 let generated_comment = format!("Generated from Python file: {}",
93 self.filename.unwrap_or_else(|| "unknown.py".to_string()));
94 stream.extend(quote! { #![doc = #generated_comment] });
95 }
96 }
97
98 if options.with_std_python {
99 stream.extend(quote!(use stdpython::*;));
102 }
103
104 let needs_async_runtime = self.raw.body.iter().any(|s| {
107 matches!(&s.statement, crate::StatementType::AsyncFunctionDef(_))
108 });
109
110 if needs_async_runtime {
111 let runtime_import = format_ident!("{}", options.async_runtime.import());
112 stream.extend(quote!(use #runtime_import;));
113 }
114
115 let mut main_body_stmts = Vec::new();
116 let mut has_main_code = false;
117 let mut has_async_functions = false;
118 let mut module_init_stmts = Vec::new();
119 let mut has_module_init_code = false;
120 let mut is_simple_main_call_pattern = false;
121
122 for s in self.raw.body {
123 if let crate::StatementType::AsyncFunctionDef(_) = &s.statement {
125 has_async_functions = true;
126 }
127
128 if let crate::StatementType::If(if_stmt) = &s.statement {
130 let test_str = format!("{:?}", if_stmt.test);
131 if test_str.contains("__name__") && test_str.contains("__main__") {
132 let is_simple_main_call = Self::is_simple_main_call_block(&if_stmt.body);
134
135 if is_simple_main_call {
136 has_main_code = true;
139 is_simple_main_call_pattern = true;
140 } else {
142 for body_stmt in &if_stmt.body {
144 let stmt_token = body_stmt
145 .clone()
146 .to_rust(ctx.clone(), options.clone(), symbols.clone())
147 .expect("parsing if __name__ body statement");
148 if !stmt_token.to_string().trim().is_empty() {
149 main_body_stmts.push(stmt_token);
150 has_main_code = true;
151 }
152 }
153 }
154 continue;
156 }
157 }
158
159 let is_declaration = Self::is_declaration_statement(&s.statement);
161
162 let statement = s
163 .clone()
164 .to_rust(ctx.clone(), options.clone(), symbols.clone())
165 .expect(format!("parsing statement {:?} in module", s).as_str());
166
167 if statement.to_string() != "" {
168 if is_declaration {
169 stream.extend(statement);
171 } else {
172 module_init_stmts.push(statement);
174 has_module_init_code = true;
175 }
176 }
177 }
178
179 if has_module_init_code {
181 stream.extend(quote! {
182 fn __module_init__() {
183 #(#module_init_stmts)*
184 }
185 });
186 }
187
188 if has_main_code {
190 if is_simple_main_call_pattern {
191 let stream_str = stream.to_string();
194
195 let user_main_is_async = stream_str.contains("pub async fn main (");
197
198 if user_main_is_async {
199 let runtime_attr = options.async_runtime.main_attribute();
201 let attr_tokens: proc_macro2::TokenStream = runtime_attr.parse()
202 .unwrap_or_else(|_| quote!(tokio::main)); let new_stream_str = stream_str
206 .replace("pub async fn main (", &format!("#[{}] async fn main(", runtime_attr));
207 stream = new_stream_str.parse::<proc_macro2::TokenStream>()
208 .unwrap_or_else(|_| stream);
209
210 if has_module_init_code {
212 let renamed_stream_str = Self::rename_main_function_and_references(&stream_str);
215 stream = renamed_stream_str.parse::<proc_macro2::TokenStream>()
216 .unwrap_or_else(|_| stream);
217
218 stream.extend(quote! {
219 #[#attr_tokens]
220 async fn main() {
221 __module_init__();
222 python_main();
223 }
224 });
225 }
226 } else {
227 let new_stream_str = Self::convert_python_main_to_rust_entry_point(&stream_str);
230 stream = new_stream_str.parse::<proc_macro2::TokenStream>()
231 .unwrap_or_else(|_| stream);
232
233 if has_module_init_code {
235 let renamed_stream_str = Self::rename_main_function_and_references(&stream_str);
237 stream = renamed_stream_str.parse::<proc_macro2::TokenStream>()
238 .unwrap_or_else(|_| stream);
239
240 stream.extend(quote! {
241 fn main() {
242 __module_init__();
243 python_main();
244 }
245 });
246 }
247 }
248 } else {
249 let stream_str = stream.to_string();
251 let has_python_main = stream_str.contains("pub fn main (") || stream_str.contains("pub async fn main (");
252
253 if has_python_main {
254 let new_stream_str = Self::rename_main_function_and_references(&stream_str);
256 stream = new_stream_str.parse::<proc_macro2::TokenStream>()
257 .unwrap_or_else(|_| stream);
258
259 for stmt in &mut main_body_stmts {
261 let stmt_str = stmt.to_string();
262 let updated_stmt_str = Self::update_main_references(&stmt_str);
263 if updated_stmt_str != stmt_str {
264 if let Ok(new_stmt) = updated_stmt_str.parse::<proc_macro2::TokenStream>() {
265 *stmt = new_stmt;
266 }
267 }
268 }
269 }
270
271 if needs_async_runtime || has_async_functions {
273 let runtime_attr = options.async_runtime.main_attribute();
275 let attr_tokens: proc_macro2::TokenStream = runtime_attr.parse()
276 .unwrap_or_else(|_| quote!(tokio::main)); if has_module_init_code {
279 stream.extend(quote! {
280 #[#attr_tokens]
281 async fn main() {
282 __module_init__();
283 #(#main_body_stmts)*
284 }
285 });
286 } else {
287 stream.extend(quote! {
288 #[#attr_tokens]
289 async fn main() {
290 #(#main_body_stmts)*
291 }
292 });
293 }
294 } else {
295 if has_module_init_code {
296 stream.extend(quote! {
297 fn main() {
298 __module_init__();
299 #(#main_body_stmts)*
300 }
301 });
302 } else {
303 stream.extend(quote! {
304 fn main() {
305 #(#main_body_stmts)*
306 }
307 });
308 }
309 }
310 }
311 } else if has_module_init_code {
312 stream.extend(quote! {
315 fn main() {
316 __module_init__();
317 }
318 });
319 }
320 Ok(stream)
321 }
322}
323
324impl Module {
325 fn is_simple_main_call_block(body: &[crate::Statement]) -> bool {
331 if body.len() != 1 {
333 return false;
334 }
335
336 let stmt = &body[0];
337 match &stmt.statement {
338 crate::StatementType::Expr(expr) => {
340 Self::is_main_function_call(&expr.value)
341 },
342 crate::StatementType::Assign(assign) => {
344 assign.targets.len() == 1 && Self::is_main_function_call(&assign.value)
346 },
347 crate::StatementType::Call(call) => {
349 call.args.iter().any(|arg| Self::is_main_function_call(arg))
351 },
352 _ => false,
353 }
354 }
355
356 fn is_main_function_call(expr: &crate::ExprType) -> bool {
358 match expr {
359 crate::ExprType::Call(call) => {
360 match call.func.as_ref() {
361 crate::ExprType::Name(name) => name.id == "main",
362 _ => false,
363 }
364 },
365 _ => false,
366 }
367 }
368
369 fn is_declaration_statement(stmt_type: &crate::StatementType) -> bool {
371 use crate::StatementType::*;
372 match stmt_type {
373 FunctionDef(_) | AsyncFunctionDef(_) | ClassDef(_) | Import(_) | ImportFrom(_) => true,
375
376 Expr(expr) => Self::is_simple_expression(&expr.value),
379
380 Assign(_) | AugAssign(_) | Call(_) | Return(_) |
382 If(_) | For(_) | While(_) | Try(_) | With(_) | AsyncWith(_) | AsyncFor(_) |
383 Raise(_) | Pass | Break | Continue => false,
384
385 Unimplemented(_) => false,
387 }
388 }
389
390 fn is_simple_expression(expr: &crate::ExprType) -> bool {
392 use crate::ExprType::*;
393 match expr {
394 Constant(_) | Name(_) | NoneType(_) => true,
396
397 UnaryOp(_) => true,
399
400 Call(_) | BinOp(_) | Compare(_) | BoolOp(_) |
402 IfExp(_) | Dict(_) | Set(_) | List(_) | Tuple(_) | ListComp(_) |
403 Lambda(_) | Attribute(_) | Subscript(_) | Starred(_) |
404 DictComp(_) | SetComp(_) | GeneratorExp(_) | Await(_) |
405 Yield(_) | YieldFrom(_) | FormattedValue(_) | JoinedStr(_) |
406 NamedExpr(_) => false,
407
408 Unimplemented(_) | Unknown => false,
410 }
411 }
412
413 fn rename_main_function_and_references(code: &str) -> String {
415 let code = code
417 .replace("pub async fn main (", "pub async fn python_main (")
418 .replace("pub fn main (", "pub fn python_main (");
419
420 Self::update_main_references(&code)
422 }
423
424 fn convert_python_main_to_rust_entry_point(code: &str) -> String {
427 use regex::Regex;
428
429 let code = code.replace("pub fn main (", "fn main(");
431
432 let main_fn_pattern = Regex::new(r"fn main\(\s*\)\s*\{([^}]*)\}").unwrap();
435
436 if let Some(captures) = main_fn_pattern.captures(&code) {
437 let body = captures.get(1).map_or("", |m| m.as_str());
438
439 if body.contains("return ") {
441 let new_code = code.replace("fn main(", "fn python_main(");
443 format!("{}\n\nfn main() {{\n let _ = python_main();\n}}", new_code)
444 } else {
445 code
447 }
448 } else {
449 code
451 }
452 }
453
454 fn update_main_references(code: &str) -> String {
457 use regex::Regex;
458
459 let call_pattern = Regex::new(r"\bmain\s*\(").unwrap();
462 let mut result = call_pattern.replace_all(code, "python_main(").to_string();
463
464 let method_pattern = Regex::new(r"\.call_main\s*\(").unwrap();
466 result = method_pattern.replace_all(&result, ".call_python_main(").to_string();
467
468 let assignment_pattern = Regex::new(r"=\s+main\b").unwrap();
471 result = assignment_pattern.replace_all(&result, "= python_main").to_string();
472
473 let return_pattern = Regex::new(r"return\s+main\b").unwrap();
475 result = return_pattern.replace_all(&result, "return python_main").to_string();
476
477 result
478 }
479
480 fn get_module_docstring(&self) -> Option<String> {
481 if self.raw.body.is_empty() {
482 return None;
483 }
484
485 let first_stmt = &self.raw.body[0];
487 match &first_stmt.statement {
488 StatementType::Expr(expr) => match &expr.value {
489 ExprType::Constant(c) => {
490 let raw_string = c.to_string();
491 Some(self.format_module_docstring(&raw_string))
492 },
493 _ => None,
494 },
495 _ => None,
496 }
497 }
498
499 fn format_module_docstring(&self, raw: &str) -> String {
500 let content = raw.trim_matches('"');
502
503 let lines: Vec<&str> = content.lines().collect();
505 if lines.is_empty() {
506 return String::new();
507 }
508
509 let mut formatted = Vec::new();
511
512 for line in lines {
513 let cleaned = line.trim();
514 if !cleaned.is_empty() {
515 formatted.push(cleaned.to_string());
516 } else {
517 formatted.push(String::new());
518 }
519 }
520
521 formatted.join("\n")
522 }
523
524 fn looks_like_module_docstring(&self) -> bool {
525 if self.raw.body.is_empty() {
526 return false;
527 }
528
529 let first_stmt = &self.raw.body[0];
531 if let StatementType::Expr(expr) = &first_stmt.statement {
532 if let ExprType::Constant(c) = &expr.value {
533 let raw_string = c.to_string();
534 let content = raw_string.trim_matches('"');
535
536 return content.lines().count() > 1
541 || content.to_lowercase().contains("module")
542 || content.to_lowercase().contains("this ")
543 || content.len() > 50; }
545 }
546 false
547 }
548}
549
550impl Object for Module {
551 fn __dir__(&self) -> Vec<impl AsRef<str>> {
553 vec![
555 "__class__",
556 "__class_getitem__",
557 "__contains__",
558 "__delattr__",
559 "__delitem__",
560 "__dir__",
561 "__doc__",
562 "__eq__",
563 "__format__",
564 "__ge__",
565 "__getattribute__",
566 "__getitem__",
567 "__getstate__",
568 "__gt__",
569 "__hash__",
570 "__init__",
571 "__init_subclass__",
572 "__ior__",
573 "__iter__",
574 "__le__",
575 "__len__",
576 "__lt__",
577 "__ne__",
578 "__new__",
579 "__or__",
580 "__reduce__",
581 "__reduce_ex__",
582 "__repr__",
583 "__reversed__",
584 "__ror__",
585 "__setattr__",
586 "__setitem__",
587 "__sizeof__",
588 "__str__",
589 "__subclasshook__",
590 "clear",
591 "copy",
592 "fromkeys",
593 "get",
594 "items",
595 "keys",
596 "pop",
597 "popitem",
598 "setdefault",
599 "update",
600 "values",
601 ]
602 }
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` (reported as coming from `filename`) and converts it to Rust.
    /// The test fails if either step fails. The conversion result is checked explicitly,
    /// because errors must not go unnoticed just because they are returned rather than
    /// raised as panics.
    fn assert_transpiles(source: &str, filename: &str) {
        let module = crate::parse(source, filename).unwrap();
        info!("Python tree: {:?}", module);
        let code = module
            .to_rust(
                CodeGenContext::Module("test_case".to_string()),
                PythonOptions::default(),
                SymbolTableScopes::new(),
            )
            .expect("module should convert to Rust");
        info!("module: {:?}", code);
    }

    #[test]
    fn can_we_print() {
        assert_transpiles(
            "#test comment\ndef foo():\n    print(\"Test print.\")\n",
            "test_case.py",
        );
    }

    #[test]
    fn can_we_import() {
        assert_transpiles("import ast", "ast.py");
    }

    #[test]
    fn can_we_import2() {
        assert_transpiles("import ast as test", "ast.py");
    }
}