use proc_macro2::TokenStream;
use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods};
use quote::quote;
use serde::{Deserialize, Serialize};
use crate::{
CodeGen, CodeGenContext, ExprType, Node, PythonOptions, SymbolTableScopes,
};
/// A single argument expression together with its source location.
///
/// The location fields are optional; when extracted from Python they are filled from the
/// `Node` location accessors of the source object.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Argument {
    /// The argument expression itself.
    pub value: ExprType,
    /// Line on which the argument starts, if known.
    pub lineno: Option<usize>,
    /// Column at which the argument starts, if known.
    pub col_offset: Option<usize>,
    /// Line on which the argument ends, if known.
    pub end_lineno: Option<usize>,
    /// Column at which the argument ends, if known.
    pub end_col_offset: Option<usize>,
}
/// A bare argument expression (an alias for [`ExprType`], without the location fields of
/// [`Argument`]).
pub type Arg = ExprType;
/// One parameter of a function signature.
///
/// Populated from an object's `arg`, `annotation` and `type_comment` attributes plus its
/// source location. Code generation renders it as `name: Type`.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Parameter {
    /// Parameter name; emitted as a Rust identifier in generated code.
    pub arg: String,
    /// Type annotation expression, if the parameter has one.
    pub annotation: Option<Box<ExprType>>,
    /// Type comment string, if present (presumably Python's `# type:` comment — confirm).
    /// Not read by `Parameter`'s code generation.
    pub type_comment: Option<String>,
    /// Line on which the parameter starts, if known.
    pub lineno: Option<usize>,
    /// Column at which the parameter starts, if known.
    pub col_offset: Option<usize>,
    /// Line on which the parameter ends, if known.
    pub end_lineno: Option<usize>,
    /// Column at which the parameter ends, if known.
    pub end_col_offset: Option<usize>,
}
/// The complete parameter list of a function definition.
///
/// Field names mirror Python's `ast.arguments` node.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Arguments {
    /// Positional-only parameters (those declared before `/`).
    pub posonlyargs: Vec<Parameter>,
    /// Regular positional-or-keyword parameters.
    pub args: Vec<Parameter>,
    /// The `*args`-style collector parameter, if any.
    pub vararg: Option<Parameter>,
    /// Keyword-only parameters.
    pub kwonlyargs: Vec<Parameter>,
    /// Defaults for `kwonlyargs`, matched by index; a `None` entry means that keyword-only
    /// parameter has no default.
    pub kw_defaults: Vec<Option<Box<ExprType>>>,
    /// The `**kwargs`-style collector parameter, if any.
    pub kwarg: Option<Parameter>,
    /// Default values for trailing positional parameters. In Python's AST this list is
    /// aligned to the end of `posonlyargs` followed by `args`.
    pub defaults: Vec<Box<ExprType>>,
}
/// The arguments supplied at a call site.
///
/// Code generation emits every positional argument first, then every keyword argument.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct CallArguments {
    /// Positional argument expressions, in call order.
    pub args: Vec<ExprType>,
    /// Keyword arguments, in call order.
    pub keywords: Vec<crate::Keyword>,
}
impl<'a> FromPyObject<'a> for Argument {
    /// Extracts the object itself as an expression, then attaches its source location.
    ///
    /// The location accessors are only consulted once the expression extracted successfully.
    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
        ob.extract::<ExprType>().map(|value| Self {
            value,
            lineno: ob.lineno(),
            col_offset: ob.col_offset(),
            end_lineno: ob.end_lineno(),
            end_col_offset: ob.end_col_offset(),
        })
    }
}
impl CodeGen for Argument {
type Context = CodeGenContext;
type Options = PythonOptions;
type SymbolTable = SymbolTableScopes;
fn to_rust(
self,
ctx: Self::Context,
options: Self::Options,
symbols: Self::SymbolTable,
) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
self.value.to_rust(ctx, options, symbols)
}
}
impl<'a> FromPyObject<'a> for Parameter {
fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
let arg: String = ob.getattr("arg")?.extract()?;
let annotation = if let Ok(ann) = ob.getattr("annotation") {
if ann.is_none() {
None
} else {
Some(Box::new(ann.extract()?))
}
} else {
None
};
let type_comment = if let Ok(tc) = ob.getattr("type_comment") {
if tc.is_none() {
None
} else {
Some(tc.extract()?)
}
} else {
None
};
Ok(Self {
arg,
annotation,
type_comment,
lineno: ob.lineno(),
col_offset: ob.col_offset(),
end_lineno: ob.end_lineno(),
end_col_offset: ob.end_col_offset(),
})
}
}
impl CodeGen for Parameter {
type Context = CodeGenContext;
type Options = PythonOptions;
type SymbolTable = SymbolTableScopes;
fn to_rust(
self,
ctx: Self::Context,
options: Self::Options,
symbols: Self::SymbolTable,
) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
use quote::format_ident;
let param_name = format_ident!("{}", self.arg);
if let Some(annotation) = self.annotation {
let rust_type = annotation.to_rust(ctx, options, symbols)?;
Ok(quote!(#param_name: #rust_type))
} else {
Ok(quote!(#param_name: impl Into<PyObject>))
}
}
}
impl<'a> FromPyObject<'a> for Arguments {
    /// Extracts a function parameter list (attribute names match Python's `ast.arguments`).
    ///
    /// # Errors
    ///
    /// Fails if `posonlyargs`, `args`, `kwonlyargs` or `defaults` is missing. It also fails if
    /// any element of those lists, of `kw_defaults`, or a present `vararg`/`kwarg` cannot be
    /// extracted.
    ///
    /// A failed list is reported rather than replaced by an empty one. An empty `args` or
    /// `defaults` would silently drop parameters or misalign defaults with the parameters
    /// they belong to.
    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
        // Reads an optional collector parameter; a missing attribute and `None` both map to
        // `None`.
        fn optional_parameter(ob: &Bound<'_, PyAny>, name: &str) -> PyResult<Option<Parameter>> {
            match ob.getattr(name) {
                Ok(value) if !value.is_none() => value.extract::<Parameter>().map(Some),
                _ => Ok(None),
            }
        }

        let posonlyargs: Vec<Parameter> = ob.getattr("posonlyargs")?.extract()?;
        let args: Vec<Parameter> = ob.getattr("args")?.extract()?;
        let vararg = optional_parameter(ob, "vararg")?;
        let kwonlyargs: Vec<Parameter> = ob.getattr("kwonlyargs")?.extract()?;
        // One entry per keyword-only parameter, Python `None` where there is no default.
        // A missing attribute is tolerated and treated as "no defaults".
        let kw_defaults: Vec<Option<Box<ExprType>>> = match ob.getattr("kw_defaults") {
            Ok(list) => list
                .extract::<Vec<Option<ExprType>>>()?
                .into_iter()
                .map(|default| default.map(Box::new))
                .collect(),
            Err(_) => Vec::new(),
        };
        let kwarg = optional_parameter(ob, "kwarg")?;
        let defaults: Vec<Box<ExprType>> = ob
            .getattr("defaults")?
            .extract::<Vec<ExprType>>()?
            .into_iter()
            .map(Box::new)
            .collect();
        Ok(Self {
            posonlyargs,
            args,
            vararg,
            kwonlyargs,
            kw_defaults,
            kwarg,
            defaults,
        })
    }
}
impl CodeGen for Arguments {
type Context = CodeGenContext;
type Options = PythonOptions;
type SymbolTable = SymbolTableScopes;
fn to_rust(
self,
ctx: Self::Context,
options: Self::Options,
symbols: Self::SymbolTable,
) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
let mut params = Vec::new();
for arg in self.posonlyargs {
let param = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
params.push(param);
}
let defaults_offset = self.args.len().saturating_sub(self.defaults.len());
for (i, arg) in self.args.into_iter().enumerate() {
if i >= defaults_offset {
let default_idx = i - defaults_offset;
let default_value = &self.defaults[default_idx];
let _default_rust = default_value.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
let param_name = quote::format_ident!("{}", arg.arg);
if let Some(annotation) = &arg.annotation {
let rust_type = annotation.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
params.push(quote!(#param_name: Option<#rust_type>));
} else {
params.push(quote!(#param_name: Option<impl Into<PyObject>>));
}
} else {
let param = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
params.push(param);
}
}
if let Some(vararg) = self.vararg {
let vararg_name = quote::format_ident!("{}", vararg.arg);
params.push(quote!(#vararg_name: impl IntoIterator<Item = impl Into<PyObject>>));
}
for (i, arg) in self.kwonlyargs.into_iter().enumerate() {
let param_name = quote::format_ident!("{}", arg.arg);
let has_default = i < self.kw_defaults.len() && self.kw_defaults[i].is_some();
if let Some(annotation) = &arg.annotation {
let rust_type = annotation.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
if has_default {
params.push(quote!(#param_name: Option<#rust_type>));
} else {
params.push(quote!(#param_name: #rust_type));
}
} else {
if has_default {
params.push(quote!(#param_name: Option<impl Into<PyObject>>));
} else {
params.push(quote!(#param_name: impl Into<PyObject>));
}
}
}
if let Some(kwarg) = self.kwarg {
let kwarg_name = quote::format_ident!("{}", kwarg.arg);
params.push(quote!(#kwarg_name: impl IntoIterator<Item = (impl AsRef<str>, impl Into<PyObject>)>));
}
Ok(quote!(#(#params),*))
}
}
impl<'a> FromPyObject<'a> for CallArguments {
    /// Extracts call-site arguments from an object's `args` and `keywords` attributes.
    ///
    /// # Errors
    ///
    /// Fails if either attribute is missing or any element cannot be extracted. The error is
    /// propagated rather than replaced by an empty list, which would silently generate a call
    /// with all of its arguments dropped.
    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
        let args: Vec<ExprType> = ob.getattr("args")?.extract()?;
        let keywords: Vec<crate::Keyword> = ob.getattr("keywords")?.extract()?;
        Ok(Self { args, keywords })
    }
}
impl CodeGen for CallArguments {
    type Context = CodeGenContext;
    type Options = PythonOptions;
    type SymbolTable = SymbolTableScopes;
    /// Renders the call arguments as one comma-separated token list.
    ///
    /// All positional arguments come first, then all keyword arguments. Conversion stops at
    /// the first error.
    fn to_rust(
        self,
        ctx: Self::Context,
        options: Self::Options,
        symbols: Self::SymbolTable,
    ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
        let mut rendered = self
            .args
            .into_iter()
            .map(|expr| expr.to_rust(ctx.clone(), options.clone(), symbols.clone()))
            .collect::<std::result::Result<Vec<_>, _>>()?;
        let keyword_tokens = self
            .keywords
            .into_iter()
            .map(|kw| kw.to_rust(ctx.clone(), options.clone(), symbols.clone()))
            .collect::<std::result::Result<Vec<_>, _>>()?;
        rendered.extend(keyword_tokens);
        Ok(quote!(#(#rendered),*))
    }
}
/// Exposes the stored source location of an [`Argument`].
impl Node for Argument {
    /// Starting line, if recorded.
    fn lineno(&self) -> Option<usize> {
        self.lineno
    }

    /// Starting column, if recorded.
    fn col_offset(&self) -> Option<usize> {
        self.col_offset
    }

    /// Ending line, if recorded.
    fn end_lineno(&self) -> Option<usize> {
        self.end_lineno
    }

    /// Ending column, if recorded.
    fn end_col_offset(&self) -> Option<usize> {
        self.end_col_offset
    }
}
/// Exposes the stored source location of a [`Parameter`].
impl Node for Parameter {
    /// Starting line, if recorded.
    fn lineno(&self) -> Option<usize> {
        self.lineno
    }

    /// Starting column, if recorded.
    fn col_offset(&self) -> Option<usize> {
        self.col_offset
    }

    /// Ending line, if recorded.
    fn end_lineno(&self) -> Option<usize> {
        self.end_lineno
    }

    /// Ending column, if recorded.
    fn end_col_offset(&self) -> Option<usize> {
        self.end_col_offset
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{parse, CodeGenContext, ExprType, PythonOptions, SymbolTableScopes};
    use test_log::test;

    /// Parses `code` as `test.py` and transpiles it in a fresh `Module("test")` context,
    /// panicking if either step returns an error.
    fn assert_transpiles(code: &str) {
        let ast = parse(code, "test.py").unwrap();
        let _rust_code = ast
            .to_rust(
                CodeGenContext::Module("test".to_string()),
                PythonOptions::default(),
                SymbolTableScopes::new(),
            )
            .unwrap();
    }

    #[test]
    fn test_simple_function_call() {
        assert_transpiles("func(1, 2, 3)");
    }

    #[test]
    fn test_keyword_arguments() {
        assert_transpiles("func(a=1, b=2)");
    }

    #[test]
    fn test_mixed_arguments() {
        assert_transpiles("func(1, 2, c=3, d=4)");
    }

    #[test]
    fn test_function_with_defaults() {
        assert_transpiles(
            r#"
def func(a, b=2, c=3):
    pass
"#,
        );
    }

    #[test]
    fn test_function_with_varargs() {
        assert_transpiles(
            r#"
def func(a, *args):
    pass
"#,
        );
    }

    #[test]
    fn test_function_with_kwargs() {
        assert_transpiles(
            r#"
def func(a, **kwargs):
    pass
"#,
        );
    }

    #[test]
    fn test_complex_function_signature() {
        assert_transpiles(
            r#"
def func(a, b=2, *args, c, d=4, **kwargs):
    pass
"#,
        );
    }

    #[test]
    fn test_keyword_only_arguments() {
        assert_transpiles(
            r#"
def func(a, *, b, c=3):
    pass
"#,
        );
    }

    /// `defaults` spans `posonlyargs + args`, so here both the positional-only `b` and the
    /// regular `c` have defaults.
    #[test]
    fn test_positional_only_with_defaults() {
        assert_transpiles(
            r#"
def func(a, b=2, /, c=3):
    pass
"#,
        );
    }

    /// Errors from either parsing or transpiling are tolerated here; the test only checks
    /// that neither step panics.
    #[test]
    fn test_argument_unpacking_call() {
        if let Ok(ast) = parse("func(*args, **kwargs)", "test.py") {
            let _ = ast.to_rust(
                CodeGenContext::Module("test".to_string()),
                PythonOptions::default(),
                SymbolTableScopes::new(),
            );
        }
    }

    #[test]
    fn test_arg_with_constant() {
        use litrs::Literal;
        let literal = Literal::parse("42").unwrap().into_owned();
        let constant = crate::Constant(Some(literal));
        let arg: Arg = ExprType::Constant(constant);
        let rust_code = arg
            .to_rust(
                CodeGenContext::Module("test".to_string()),
                PythonOptions::default(),
                SymbolTableScopes::new(),
            )
            .unwrap();
        assert!(rust_code.to_string().contains("42"));
    }

    #[test]
    fn test_arg_with_name() {
        let name_expr = ExprType::Name(crate::Name {
            id: "variable".to_string(),
        });
        let arg: Arg = name_expr;
        let rust_code = arg
            .to_rust(
                CodeGenContext::Module("test".to_string()),
                PythonOptions::default(),
                SymbolTableScopes::new(),
            )
            .unwrap();
        assert!(rust_code.to_string().contains("variable"));
    }
}