use proc_macro2::TokenStream;
use pyo3::{Bound, FromPyObject, PyAny, PyResult};
use quote::quote;
use serde::{Deserialize, Serialize};
use crate::{CodeGen, CodeGenContext, ExprType, Keyword, PythonOptions, SymbolTableScopes, extract_required_attr};
/// A function call expression.
///
/// The fields mirror the `func`, `args` and `keywords` attributes that the
/// `FromPyObject` implementation for this type reads from the source object.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Call {
/// The expression being called (the callee).
pub func: Box<ExprType>,
/// Positional argument expressions.
pub args: Vec<ExprType>,
/// Keyword arguments supplied to the call.
pub keywords: Vec<Keyword>,
}
impl<'a> FromPyObject<'a> for Call {
    /// Builds a [`Call`] from the `func`, `args` and `keywords` attributes of `ob`.
    ///
    /// The attributes are read in that order through `extract_required_attr`;
    /// the first extraction that fails is propagated to the caller.
    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
        // The callee is bound first so it can be boxed inside the struct literal.
        let callee: ExprType = extract_required_attr(ob, "func", "function call expression")?;

        Ok(Self {
            func: Box::new(callee),
            args: extract_required_attr(ob, "args", "function call arguments")?,
            keywords: extract_required_attr(ob, "keywords", "function call keywords")?,
        })
    }
}
impl CodeGen for Call {
    type Context = CodeGenContext;
    type Options = PythonOptions;
    type SymbolTable = SymbolTableScopes;

    /// Generates the Rust token stream for this call expression.
    ///
    /// The callee and every argument are converted with their own `to_rust`;
    /// positional arguments come first, followed by keyword arguments. The
    /// rendered callee path is then used to pick the shape of the output:
    ///
    /// * `subprocess :: run` expands to a block that binds the first argument
    ///   as `Vec<String>`, borrows it as `Vec<&str>`, and passes the second
    ///   generated argument (when present) as `Some(&cwd_str)`, otherwise `None`.
    /// * `os :: execv` expands to a block that clones the program argument
    ///   into a `String` and borrows the argument list as `Vec<&str>`.
    /// * A fixed set of other `subprocess` / `os` functions get `.unwrap()`
    ///   appended to the plain call.
    /// * Anything else is emitted as a plain call.
    ///
    /// In an `Async` context `.await` is appended when the rendered callee
    /// contains `"async"` or starts with `'a'`.
    ///
    /// # Errors
    ///
    /// Propagates any error from converting the callee or an argument, and
    /// returns an error (instead of panicking on an out-of-bounds index) when
    /// `subprocess.run` is called without arguments or `os.execv` with fewer
    /// than two.
    fn to_rust(
        self,
        ctx: Self::Context,
        options: Self::Options,
        symbols: Self::SymbolTable,
    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
        let name = self.func.to_rust(ctx.clone(), options.clone(), symbols.clone())?;

        // Positional arguments first, then keywords, all in one flat list.
        let mut all_args = Vec::with_capacity(self.args.len() + self.keywords.len());
        for arg in self.args {
            all_args.push(arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?);
        }
        for keyword in self.keywords {
            all_args.push(keyword.to_rust(ctx.clone(), options.clone(), symbols.clone())?);
        }

        let call_expr = quote!(#name(#(#all_args),*));

        // `TokenStream`'s `Display` separates tokens with spaces, hence the
        // `"a :: b"` spelling of the paths matched below.
        let name_str = name.to_string();

        let final_call = match name_str.as_str() {
            "subprocess :: run" => match all_args.as_slice() {
                [] => {
                    return Err(
                        "subprocess.run() requires at least one argument (the command list)"
                            .into(),
                    );
                }
                [args_param] => quote!({
                    let args_owned: Vec<String> = #args_param;
                    let args_vec: Vec<&str> = args_owned.iter().map(|s| s.as_str()).collect();
                    subprocess::run(args_vec, None).unwrap()
                }),
                // Arguments beyond the second are ignored, as before.
                [args_param, cwd_param, ..] => quote!({
                    let args_owned: Vec<String> = #args_param;
                    let args_vec: Vec<&str> = args_owned.iter().map(|s| s.as_str()).collect();
                    let cwd_str = #cwd_param;
                    subprocess::run(args_vec, Some(&cwd_str)).unwrap()
                }),
            },
            "os :: execv" => match all_args.as_slice() {
                [program_param, args_param, ..] => quote!({
                    let program_str: String = (#program_param).clone();
                    let args_owned: Vec<String> = #args_param;
                    let args_vec: Vec<&str> = args_owned.iter().map(|s| s.as_str()).collect();
                    os::execv(&program_str, args_vec).unwrap()
                }),
                _ => {
                    return Err(format!(
                        "os.execv() requires two arguments (program, args), got {}",
                        all_args.len()
                    )
                    .into());
                }
            },
            // These callees are emitted as a plain call followed by `.unwrap()`.
            "subprocess :: run_with_env"
            | "subprocess :: check_call"
            | "subprocess :: check_output"
            | "os :: getcwd"
            | "os :: chdir"
            | "os :: path :: abspath" => quote!(#call_expr.unwrap()),
            _ => call_expr,
        };

        // NOTE(review): the `starts_with('a')` test awaits *every* callee whose
        // rendered path begins with `a` when in an async context. It is kept
        // unchanged here; confirm that this heuristic is intentional.
        let needs_await = matches!(ctx, CodeGenContext::Async(_))
            && (name_str.contains("async") || name_str.starts_with('a'));

        if needs_await {
            Ok(quote!(#final_call.await))
        } else {
            Ok(final_call)
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Smoke test: parses a small Python snippet that defines `foo(a = 7)` and then
// calls `foo(b=9)`, and requires both parsing and Rust code generation to
// succeed. The generated tokens are discarded; only the `Ok` results are checked.
// NOTE(review): as shown here the `pass` line inside the string literal has no
// leading indentation; confirm the literal keeps the function body indented.
#[test]
fn test_lookup_of_function() {
let options = PythonOptions::default();
let result = crate::parse(
"def foo(a = 7):
pass
foo(b=9)",
"test.py",
)
.unwrap();
// Convert in a module context named "test" with a fresh symbol table.
let _code = result
.to_rust(
CodeGenContext::Module("test".to_string()),
options,
SymbolTableScopes::new(),
)
.unwrap();
}
}