python_ast/ast/tree/
call.rs1use proc_macro2::TokenStream;
2use pyo3::{Bound, FromPyObject, PyAny, PyResult};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{CodeGen, CodeGenContext, ExprType, Keyword, PythonOptions, SymbolTableScopes, extract_required_attr};
7
8#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
9pub struct Call {
10 pub func: Box<ExprType>,
11 pub args: Vec<ExprType>,
12 pub keywords: Vec<Keyword>,
13}
14
15impl<'a> FromPyObject<'a> for Call {
16 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
17 let func: ExprType = extract_required_attr(ob, "func", "function call expression")?;
18 let args: Vec<ExprType> = extract_required_attr(ob, "args", "function call arguments")?;
19 let keywords: Vec<Keyword> = extract_required_attr(ob, "keywords", "function call keywords")?;
20
21 Ok(Call {
22 func: Box::new(func),
23 args,
24 keywords,
25 })
26 }
27}
28
29impl<'a> CodeGen for Call {
30 type Context = CodeGenContext;
31 type Options = PythonOptions;
32 type SymbolTable = SymbolTableScopes;
33
34 fn to_rust(
35 self,
36 ctx: Self::Context,
37 options: Self::Options,
38 symbols: Self::SymbolTable,
39 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
40 let name = self.func.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
41
42 let mut all_args = Vec::new();
43
44 for arg in self.args {
46 let rust_arg = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
47 all_args.push(rust_arg);
48 }
49
50 for keyword in self.keywords {
52 let rust_kw = keyword.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
53 all_args.push(rust_kw);
54 }
55
56 let call_expr = quote!(#name(#(#all_args),*));
58
59 let name_str = format!("{}", name);
61 let needs_unwrap = matches!(name_str.as_str(),
62 "subprocess :: run" | "subprocess :: run_with_env" | "subprocess :: check_call" |
63 "subprocess :: check_output" | "os :: getcwd" | "os :: chdir" | "os :: execv" |
64 "os :: path :: abspath"
65 );
66
67 let final_call = if name_str == "subprocess :: run" {
69 if all_args.len() >= 2 {
71 let args_param = &all_args[0];
72 let cwd_param = &all_args[1];
73 quote!({
75 let args_owned: Vec<String> = #args_param;
76 let args_vec: Vec<&str> = args_owned.iter().map(|s| s.as_str()).collect();
77 let cwd_str = #cwd_param;
78 subprocess::run(args_vec, Some(&cwd_str)).unwrap()
79 })
80 } else {
81 let args_param = &all_args[0];
82 quote!({
83 let args_owned: Vec<String> = #args_param;
84 let args_vec: Vec<&str> = args_owned.iter().map(|s| s.as_str()).collect();
85 subprocess::run(args_vec, None).unwrap()
86 })
87 }
88 } else if name_str == "os :: execv" {
89 let program_param = &all_args[0];
91 let args_param = &all_args[1];
92 quote!({
93 let program_str: String = (#program_param).clone();
94 let args_owned: Vec<String> = #args_param;
95 let args_vec: Vec<&str> = args_owned.iter().map(|s| s.as_str()).collect();
96 os::execv(&program_str, args_vec).unwrap()
97 })
98 } else if needs_unwrap {
99 quote!(#call_expr.unwrap())
100 } else {
101 call_expr
102 };
103
104 match ctx {
105 CodeGenContext::Async(_) => {
106 if name_str.contains("async") ||
109 name_str.starts_with("a") || false {
112 Ok(quote!(#final_call.await))
113 } else {
114 Ok(final_call)
117 }
118 },
119 _ => Ok(final_call)
120 }
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Transpiles a module that defines `foo(a = 7)` and then calls it as
    /// `foo(b=9)`. The test only checks that parsing and Rust generation both
    /// succeed.
    #[test]
    fn test_lookup_of_function() {
        const SOURCE: &str = "def foo(a = 7):\n    pass\n\nfoo(b=9)";

        let options = PythonOptions::default();
        let module = crate::parse(SOURCE, "test.py").unwrap();
        let ctx = CodeGenContext::Module("test".to_string());
        let _code = module
            .to_rust(ctx, options, SymbolTableScopes::new())
            .unwrap();
    }
}