// python_ast/ast/tree/call.rs

1use proc_macro2::TokenStream;
2use pyo3::{Bound, FromPyObject, PyAny, PyResult};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{CodeGen, CodeGenContext, ExprType, Keyword, PythonOptions, SymbolTableScopes, extract_required_attr};
7
8#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
9pub struct Call {
10    pub func: Box<ExprType>,
11    pub args: Vec<ExprType>,
12    pub keywords: Vec<Keyword>,
13}
14
15impl<'a> FromPyObject<'a> for Call {
16    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
17        let func: ExprType = extract_required_attr(ob, "func", "function call expression")?;
18        let args: Vec<ExprType> = extract_required_attr(ob, "args", "function call arguments")?;
19        let keywords: Vec<Keyword> = extract_required_attr(ob, "keywords", "function call keywords")?;
20        
21        Ok(Call {
22            func: Box::new(func),
23            args,
24            keywords,
25        })
26    }
27}
28
29impl<'a> CodeGen for Call {
30    type Context = CodeGenContext;
31    type Options = PythonOptions;
32    type SymbolTable = SymbolTableScopes;
33
34    fn to_rust(
35        self,
36        ctx: Self::Context,
37        options: Self::Options,
38        symbols: Self::SymbolTable,
39    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
40        let name = self.func.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
41        
42        let mut all_args = Vec::new();
43        
44        // Add positional arguments
45        for arg in self.args {
46            let rust_arg = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
47            all_args.push(rust_arg);
48        }
49        
50        // Add keyword arguments
51        for keyword in self.keywords {
52            let rust_kw = keyword.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
53            all_args.push(rust_kw);
54        }
55        
56        // Check if we're in an async context and if the function being called is async
57        let call_expr = quote!(#name(#(#all_args),*));
58        
59        // Check if this function returns a Result that should be unwrapped
60        let name_str = format!("{}", name);
61        let needs_unwrap = matches!(name_str.as_str(), 
62            "subprocess :: run" | "subprocess :: run_with_env" | "subprocess :: check_call" | 
63            "subprocess :: check_output" | "os :: getcwd" | "os :: chdir" | "os :: execv" |
64            "os :: path :: abspath"
65        );
66        
67        // Special handling for subprocess.run and os.execv with fallback for compatibility
68        let final_call = if name_str == "subprocess :: run" {
69            // Try mixed_args version first, fallback to regular version
70            if all_args.len() >= 2 {
71                let args_param = &all_args[0];
72                let cwd_param = &all_args[1];
73                // Convert args to Vec<String> to avoid lifetime issues, then pass owned strings
74                quote!({
75                    let args_owned: Vec<String> = #args_param;
76                    let args_vec: Vec<&str> = args_owned.iter().map(|s| s.as_str()).collect();
77                    let cwd_str = #cwd_param;
78                    subprocess::run(args_vec, Some(&cwd_str)).unwrap()
79                })
80            } else {
81                let args_param = &all_args[0];
82                quote!({
83                    let args_owned: Vec<String> = #args_param;
84                    let args_vec: Vec<&str> = args_owned.iter().map(|s| s.as_str()).collect();
85                    subprocess::run(args_vec, None).unwrap()
86                })
87            }
88        } else if name_str == "os :: execv" {
89            // Convert to Vec<&str> for compatibility with standard execv function
90            let program_param = &all_args[0];
91            let args_param = &all_args[1];
92            quote!({
93                let program_str: String = (#program_param).clone();
94                let args_owned: Vec<String> = #args_param;
95                let args_vec: Vec<&str> = args_owned.iter().map(|s| s.as_str()).collect();
96                os::execv(&program_str, args_vec).unwrap()
97            })
98        } else if needs_unwrap {
99            quote!(#call_expr.unwrap())
100        } else {
101            call_expr
102        };
103        
104        match ctx {
105            CodeGenContext::Async(_) => {
106                // In async context, we assume Python async functions need .await
107                // We'll check if the function name suggests it's async
108                if name_str.contains("async") || 
109                   name_str.starts_with("a") || // Common async function naming
110                   // TODO: Better async function detection based on symbol table
111                   false {
112                    Ok(quote!(#final_call.await))
113                } else {
114                    // For now, just return the regular call
115                    // In a full implementation, we'd track which functions are async
116                    Ok(final_call)
117                }
118            },
119            _ => Ok(final_call)
120        }
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Generating Rust for a module that defines `foo` and then calls it with
    /// a keyword argument must complete without returning an error.
    #[test]
    fn test_lookup_of_function() {
        let source = "def foo(a = 7):\n    pass\n\nfoo(b=9)";
        let module = crate::parse(source, "test.py").unwrap();

        let context = CodeGenContext::Module("test".to_string());
        let _generated = module
            .to_rust(context, PythonOptions::default(), SymbolTableScopes::new())
            .unwrap();
    }
}