use std::fmt::Debug;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{format_ident, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse2, parse_quote,
token::Colon,
Attribute, Ident, ItemFn, Signature, Stmt, Token,
};
pub fn impl_pyo3test(_attr: TokenStream2, input: TokenStream2) -> TokenStream2 {
let testcase: Pyo3TestCase = match parse2::<ItemFn>(input).and_then(|itemfn| itemfn.try_into())
{
Ok(testcase) => testcase,
Err(e) => return e.into_compile_error(),
};
wrap_testcase(testcase)
}
struct Pyo3TestCase {
pyo3imports: Vec<Pyo3Import>,
signature: Signature,
statements: Vec<Stmt>,
otherattributes: Vec<Attribute>,
}
impl TryFrom<ItemFn> for Pyo3TestCase {
type Error = syn::Error;
fn try_from(testcase: ItemFn) -> syn::Result<Pyo3TestCase> {
let mut pyo3imports = Vec::<Pyo3Import>::new();
let mut otherattributes = Vec::<Attribute>::new();
for attr in testcase.attrs {
if attr.path().is_ident("pyo3import") {
pyo3imports.push(attr.parse_args()?);
} else {
otherattributes.push(attr);
};
}
Ok(Pyo3TestCase {
pyo3imports,
signature: testcase.sig,
statements: testcase.block.stmts,
otherattributes,
})
}
}
#[derive(Debug, PartialEq)]
struct Pyo3Import {
o3_moduleident: Ident,
py_modulename: String,
py_functionname: Option<String>,
}
impl Parse for Pyo3Import {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let o3_moduleident;
if input.peek2(Token![:]) {
o3_moduleident = input.parse()?;
let _: Colon = input.parse()?;
} else {
return Err(input.error("invalid import statement: expected a colon (':') after this"));
}
let firstkeyword: PythonImportKeyword = input.parse()?;
let py_modulename = input.parse::<Ident>()?.to_string();
let py_functionname = match firstkeyword {
PythonImportKeyword::from => {
let _import: PythonImportKeyword = input.parse()?;
Some(input.parse::<Ident>()?.to_string())
}
PythonImportKeyword::import => None,
};
Ok(Pyo3Import {
o3_moduleident,
py_modulename,
py_functionname,
})
}
}
#[allow(non_camel_case_types)] #[derive(Debug, PartialEq)]
enum PythonImportKeyword {
from,
import,
}
impl Parse for PythonImportKeyword {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let keyword = input.parse::<Ident>()?;
match keyword.to_string().as_str() {
"from" => Ok(PythonImportKeyword::from),
"import" => Ok(PythonImportKeyword::import),
_ => Err(syn::Error::new(
keyword.span(),
"invalid import statement: expect 'from' or 'import' here",
)),
}
}
}
#[allow(non_snake_case)] fn wrap_testcase(mut testcase: Pyo3TestCase) -> TokenStream2 {
let mut o3_moduleidents = Vec::<Ident>::new(); let mut o3_pymoduledefidents = Vec::<Ident>::new(); let mut o3_pymoduleidents = Vec::<Ident>::new(); let mut py_moduleidents = Vec::<Ident>::new(); let mut py_modulenames = Vec::<String>::new(); let mut py_ModuleNotFoundErrormsgs = Vec::<String>::new(); let mut py_functionidents = Vec::<Ident>::new(); let mut py_macroidents = Vec::<Ident>::new(); let mut py_moduleswithfnsidents = Vec::<Ident>::new(); let mut py_functionnames = Vec::<String>::new(); let mut py_AttributeErrormsgs = Vec::<String>::new();
for pyo3import in testcase.pyo3imports {
let py_modulename = pyo3import.py_modulename;
if let Some(py_functionname) = pyo3import.py_functionname {
py_AttributeErrormsgs
.push("Failed to get ".to_string() + &py_functionname + " function");
py_functionidents.push(Ident::new(&py_functionname, Span::call_site()));
py_macroidents.push(Ident::new(&py_functionname, Span::call_site()));
py_moduleswithfnsidents.push(Ident::new(&py_modulename, Span::call_site()));
py_functionnames.push(py_functionname);
};
py_ModuleNotFoundErrormsgs.push("Failed to import ".to_string() + &py_modulename);
py_moduleidents.push(Ident::new(&py_modulename, Span::call_site()));
py_modulenames.push(py_modulename);
o3_pymoduledefidents.push(format_ident!("{}_pymoduledef", pyo3import.o3_moduleident));
o3_pymoduleidents.push(format_ident!("{}_pymodule", pyo3import.o3_moduleident));
o3_moduleidents.push(pyo3import.o3_moduleident);
}
let testfn_signature = testcase.signature;
let testfn_statements = testcase.statements;
let mut testfn: ItemFn = parse_quote!(
#[test]
#testfn_signature {
use pyo3::types::PyDict;
Python::initialize();
Python::attach(|py| {
let sys = PyModule::import(py, "sys").unwrap();
let sys_modules: Bound<'_, PyDict> =
sys.getattr("modules").unwrap().cast_into().unwrap();
#(
let #o3_pymoduledefidents = &#o3_moduleidents::_PYO3_DEF;
let #o3_pymoduleidents = #o3_pymoduledefidents
.make_module(py)
.unwrap();
let #o3_pymoduleidents = #o3_pymoduleidents.bind(py);
sys_modules
.set_item(#py_modulenames, #o3_pymoduleidents)
.expect(#py_ModuleNotFoundErrormsgs);
let #py_moduleidents = sys_modules.get_item(#py_modulenames).unwrap().unwrap();
)*
#(
let #py_functionidents = #py_moduleswithfnsidents
.getattr(#py_functionnames)
.expect(#py_AttributeErrormsgs);
macro_rules! #py_macroidents {
($($arg:tt),+) => {
#py_functionidents
.call1(($($arg,)+))
.unwrap()
.extract()
.unwrap()
};
(*$args:ident) => {
#py_functionidents
.call1($args)
.unwrap()
.extract()
.unwrap()
};
() => {
#py_functionidents
.call0()
.unwrap()
.extract()
.unwrap()
};
};
)*
#(#testfn_statements)*
});
}
);
testfn.attrs.append(&mut testcase.otherattributes);
testfn.into_token_stream()
}
#[cfg(test)]
mod tests {
    use quote::quote;

    use super::*;

    /// Two `from ... import ...` statements interleaved with an unrelated attribute: the
    /// unrelated attribute is kept after `#[test]`, and each function gets a binding + macro.
    #[test]
    fn test_other_attribute() {
        let testcase: TokenStream2 = quote! {
            #[pyo3import(py_fizzbuzzo3: from fizzbuzzo3 import fizzbuzz)]
            #[anotherattribute]
            #[pyo3import(foo_o3: from pyfoo import pybar)]
            fn test_fizzbuzz() {
                assert!(true)
            }
        };
        let expected: TokenStream2 = quote! {
            #[test]
            #[anotherattribute]
            fn test_fizzbuzz() {
                use pyo3::types::PyDict;
                Python::initialize();
                Python::attach(|py| {
                    let sys = PyModule::import(py, "sys").unwrap();
                    let sys_modules: Bound<'_, PyDict> =
                        sys.getattr("modules").unwrap().cast_into().unwrap();
                    let py_fizzbuzzo3_pymoduledef = &py_fizzbuzzo3::_PYO3_DEF;
                    let py_fizzbuzzo3_pymodule = py_fizzbuzzo3_pymoduledef
                        .make_module(py)
                        .unwrap();
                    let py_fizzbuzzo3_pymodule = py_fizzbuzzo3_pymodule.bind(py);
                    sys_modules
                        .set_item("fizzbuzzo3", py_fizzbuzzo3_pymodule)
                        .expect("Failed to import fizzbuzzo3");
                    let fizzbuzzo3 = sys_modules.get_item("fizzbuzzo3").unwrap().unwrap();
                    let foo_o3_pymoduledef = &foo_o3::_PYO3_DEF;
                    let foo_o3_pymodule = foo_o3_pymoduledef
                        .make_module(py)
                        .unwrap();
                    let foo_o3_pymodule = foo_o3_pymodule.bind(py);
                    sys_modules
                        .set_item("pyfoo", foo_o3_pymodule)
                        .expect("Failed to import pyfoo");
                    let pyfoo = sys_modules.get_item("pyfoo").unwrap().unwrap();
                    let fizzbuzz = fizzbuzzo3
                        .getattr("fizzbuzz")
                        .expect("Failed to get fizzbuzz function");
                    macro_rules! fizzbuzz {
                        ($($arg:tt),+) => {
                            fizzbuzz
                                .call1(($($arg,)+))
                                .unwrap()
                                .extract()
                                .unwrap()
                        };
                        (*$args:ident) => {
                            fizzbuzz
                                .call1($args)
                                .unwrap()
                                .extract()
                                .unwrap()
                        };
                        () => {
                            fizzbuzz
                                .call0()
                                .unwrap()
                                .extract()
                                .unwrap()
                        };
                    };
                    let pybar = pyfoo
                        .getattr("pybar")
                        .expect("Failed to get pybar function");
                    macro_rules! pybar {
                        ($($arg:tt),+) => {
                            pybar
                                .call1(($($arg,)+))
                                .unwrap()
                                .extract()
                                .unwrap()
                        };
                        (*$args:ident) => {
                            pybar
                                .call1($args)
                                .unwrap()
                                .extract()
                                .unwrap()
                        };
                        () => {
                            pybar
                                .call0()
                                .unwrap()
                                .extract()
                                .unwrap()
                        };
                    };
                    assert!(true)
                });
            }
        };
        let output: TokenStream2 = impl_pyo3test(quote! {}, testcase);
        assert_eq!(output.to_string(), expected.to_string());
    }

    /// The bare `import` form registers and binds the module but generates no function
    /// bindings or macros.
    #[test]
    fn test_import_module_only() {
        let testcase: TokenStream2 = quote! {
            #[pyo3import(foo_o3: import pyfoo)]
            fn test_import() {
                assert!(true)
            }
        };
        let expected: TokenStream2 = quote! {
            #[test]
            fn test_import() {
                use pyo3::types::PyDict;
                Python::initialize();
                Python::attach(|py| {
                    let sys = PyModule::import(py, "sys").unwrap();
                    let sys_modules: Bound<'_, PyDict> =
                        sys.getattr("modules").unwrap().cast_into().unwrap();
                    let foo_o3_pymoduledef = &foo_o3::_PYO3_DEF;
                    let foo_o3_pymodule = foo_o3_pymoduledef
                        .make_module(py)
                        .unwrap();
                    let foo_o3_pymodule = foo_o3_pymodule.bind(py);
                    sys_modules
                        .set_item("pyfoo", foo_o3_pymodule)
                        .expect("Failed to import pyfoo");
                    let pyfoo = sys_modules.get_item("pyfoo").unwrap().unwrap();
                    assert!(true)
                });
            }
        };
        let output: TokenStream2 = impl_pyo3test(quote! {}, testcase);
        assert_eq!(output.to_string(), expected.to_string());
    }

    /// A malformed `#[pyo3import]` is reported as a `compile_error!` rather than a panic.
    #[test]
    fn test_missing_colon_is_compile_error() {
        let testcase: TokenStream2 = quote! {
            #[pyo3import(foo_o3 from pyfoo import pybar)]
            fn test_bad() {}
        };
        let output = impl_pyo3test(quote! {}, testcase).to_string();
        assert!(output.contains("compile_error"));
        assert!(output.contains("expected a colon"));
    }
}