use crate::{
attributes::{self, take_pyo3_options},
deprecations::Deprecations,
pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
};
use crate::{
attributes::{is_attribute_ident, take_attributes, NameAttribute},
deprecations::Deprecation,
};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
ext::IdentExt,
parse::{Parse, ParseStream},
spanned::Spanned,
token::Comma,
Ident, Path, Result,
};
/// Options collected for a `#[pymodule]` function.
pub struct PyModuleOptions {
    /// Python-visible module name, if one was given explicitly: either the deprecated
    /// positional `#[pymodule(...)]` argument or a `name` option taken from the
    /// `#[pyo3(...)]` attributes. When `None`, `py_init` falls back to the Rust
    /// function's identifier with any `r#` prefix removed.
    name: Option<syn::Ident>,
    /// Deprecation warnings that are interpolated into the generated `PyInit_*` body.
    deprecations: Deprecations,
}
impl PyModuleOptions {
pub fn from_pymodule_arg_and_attrs(
deprecated_pymodule_name_arg: Option<syn::Ident>,
attrs: &mut Vec<syn::Attribute>,
) -> Result<Self> {
let mut deprecations = Deprecations::new();
if let Some(name) = &deprecated_pymodule_name_arg {
deprecations.push(Deprecation::PyModuleNameArgument, name.span());
}
let mut options: PyModuleOptions = PyModuleOptions {
name: deprecated_pymodule_name_arg,
deprecations,
};
for option in take_pyo3_options(attrs)? {
match option {
PyModulePyO3Option::Name(name) => options.set_name(name.0)?,
}
}
Ok(options)
}
fn set_name(&mut self, name: syn::Ident) -> Result<()> {
ensure_spanned!(
self.name.is_none(),
name.span() => "`name` may only be specified once"
);
self.name = Some(name);
Ok(())
}
}
pub fn py_init(fnname: &Ident, options: PyModuleOptions, doc: syn::LitStr) -> TokenStream {
let name = options.name.unwrap_or_else(|| fnname.unraw());
let deprecations = options.deprecations;
let cb_name = Ident::new(&format!("PyInit_{}", name), Span::call_site());
assert!(doc.value().ends_with('\0'));
quote! {
#[no_mangle]
#[allow(non_snake_case)]
pub unsafe extern "C" fn #cb_name() -> *mut pyo3::ffi::PyObject {
use pyo3::derive_utils::ModuleDef;
static NAME: &str = concat!(stringify!(#name), "\0");
static DOC: &str = #doc;
static MODULE_DEF: ModuleDef = unsafe { ModuleDef::new(NAME, DOC) };
#deprecations
pyo3::callback::handle_panic(|_py| { MODULE_DEF.make_module(_py, #fnname) })
}
}
}
pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
let mut stmts: Vec<syn::Stmt> = Vec::new();
for mut stmt in func.block.stmts.drain(..) {
if let syn::Stmt::Item(syn::Item::Fn(func)) = &mut stmt {
if let Some(pyfn_args) = get_pyfn_attr(&mut func.attrs)? {
let module_name = pyfn_args.modname;
let (ident, wrapped_function) = impl_wrap_pyfunction(func, pyfn_args.options)?;
let statements: Vec<syn::Stmt> = syn::parse_quote! {
#wrapped_function
#module_name.add_function(#ident(#module_name)?)?;
};
stmts.extend(statements);
}
};
stmts.push(stmt);
}
func.block.stmts = stmts;
Ok(())
}
/// Parsed arguments of a `#[pyfn(...)]` attribute.
pub struct PyFnArgs {
    /// Path given as the first `#[pyfn]` argument; the wrapped function is registered
    /// on it via `add_function`.
    modname: Path,
    /// Function options from the remaining `#[pyfn]` arguments. `get_pyfn_attr` then
    /// extends them with the function's `#[pyo3(...)]` attributes through
    /// `take_pyo3_options`.
    options: PyFunctionOptions,
}
impl Parse for PyFnArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let modname = input.parse().map_err(
|e| err_spanned!(e.span() => "expected module as first argument to #[pyfn()]"),
)?;
if input.is_empty() {
return Ok(Self {
modname,
options: Default::default(),
});
}
let _: Comma = input.parse()?;
let mut deprecated_name_argument = None;
if let Ok(lit_str) = input.parse::<syn::LitStr>() {
deprecated_name_argument = Some(lit_str);
if !input.is_empty() {
let _: Comma = input.parse()?;
}
}
let mut options: PyFunctionOptions = input.parse()?;
if let Some(lit_str) = deprecated_name_argument {
options.set_name(NameAttribute(lit_str.parse()?))?;
options
.deprecations
.push(Deprecation::PyfnNameArgument, lit_str.span());
}
Ok(Self { modname, options })
}
}
/// Takes the `#[pyfn(...)]` attribute out of `attrs` (via `take_attributes`) and
/// parses its arguments.
///
/// When a `#[pyfn]` attribute was found, its options are then passed to
/// `PyFunctionOptions::take_pyo3_options` together with the remaining `attrs`.
/// Returns `Ok(None)` when the function has no `#[pyfn]` attribute.
///
/// # Errors
/// Fails if `#[pyfn]` appears more than once, if its arguments do not parse, or if
/// `take_pyo3_options` reports an error.
fn get_pyfn_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<PyFnArgs>> {
    let mut pyfn_args: Option<PyFnArgs> = None;
    take_attributes(attrs, |attr| {
        if is_attribute_ident(attr, "pyfn") {
            ensure_spanned!(
                pyfn_args.is_none(),
                attr.span() => "`#[pyfn]` may only be specified once"
            );
            pyfn_args = Some(attr.parse_args()?);
            Ok(true)
        } else {
            Ok(false)
        }
    })?;
    if let Some(args) = pyfn_args.as_mut() {
        args.options.take_pyo3_options(attrs)?;
    }
    Ok(pyfn_args)
}
/// A single option accepted inside `#[pyo3(...)]` on a `#[pymodule]` function.
enum PyModulePyO3Option {
    /// The `name` option, which sets the Python-visible module name.
    Name(NameAttribute),
}
impl Parse for PyModulePyO3Option {
fn parse(input: ParseStream) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
input.parse().map(PyModulePyO3Option::Name)
} else {
Err(lookahead.error())
}
}
}