use super::Diagnostic;
use crate::util::{def_to_name, ItemIdent, ItemMeta};
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{quote, quote_spanned, ToTokens};
use std::collections::HashSet;
use syn::{parse_quote, spanned::Spanned, Attribute, AttributeArgs, Ident, Item, Meta, NestedMeta};
/// Flattens an attribute's `Meta` into its list of nested arguments.
///
/// A bare path (`#[foo]`) yields an empty list and a list (`#[foo(a, b)]`)
/// yields its nested entries in order. A name/value form (`#[foo = "..."]`)
/// has no nested list, so the original `Meta` is handed back as the error
/// value for the caller to report.
fn meta_to_vec(meta: Meta) -> Result<Vec<NestedMeta>, Meta> {
    let nested = match meta {
        Meta::Path(_) => Vec::new(),
        Meta::List(list) => list.nested.into_iter().collect(),
        name_value @ Meta::NameValue(_) => return Err(name_value),
    };
    Ok(nested)
}
/// Accumulator for the items discovered inside a `#[pymodule]` module body.
#[derive(Default)]
struct Module {
// Distinct items collected so far. Being a set, inserting an entry equal to
// an existing one fails, which `add_item` reports as a duplicate error.
items: HashSet<ModuleItem>,
}
/// One item of the module body that gets registered as a module attribute.
///
/// In both variants `item_ident` is the Rust identifier of the annotated item
/// and `py_name` is the attribute name used when it is set on the module.
#[derive(PartialEq, Eq, Hash)]
enum ModuleItem {
// Built from a `#[pyfunction]` attribute.
Function { item_ident: Ident, py_name: String },
// Built from a `#[pyclass]` attribute.
Class { item_ident: Ident, py_name: String },
}
impl Module {
    /// Records `item` as part of the module.
    ///
    /// Returns an error located at `span` when an equal item (same variant,
    /// identifier and Python name) has already been recorded.
    fn add_item(&mut self, item: ModuleItem, span: Span) -> Result<(), Diagnostic> {
        if self.items.insert(item) {
            Ok(())
        } else {
            Err(Diagnostic::span_error(
                span,
                "Duplicate #[py*] attribute on pymodule".to_owned(),
            ))
        }
    }

    /// Builds a `ModuleItem::Function` for `ident` from a `#[pyfunction]` attribute.
    ///
    /// The name/value form (`#[pyfunction = "..."]`) is rejected with a hint to
    /// use `#[pyfunction(name = "...")]`. The Python name is whatever
    /// `ItemMeta::simple_name` returns for the parsed arguments.
    fn extract_function(ident: &Ident, meta: Meta) -> Result<ModuleItem, Diagnostic> {
        let nesteds = meta_to_vec(meta).map_err(|meta| {
            err_span!(
                meta,
                "#[pyfunction = \"...\"] cannot be a name/value, you probably meant \
                 #[pyfunction(name = \"...\")]",
            )
        })?;
        let item_meta =
            ItemMeta::from_nested_meta("pyfunction", ident, &nesteds, ItemMeta::SIMPLE_NAMES)?;
        Ok(ModuleItem::Function {
            item_ident: ident.clone(),
            py_name: item_meta.simple_name()?,
        })
    }

    /// Builds a `ModuleItem::Class` for `ident` from a `#[pyclass]` attribute.
    ///
    /// The name/value form (`#[pyclass = "..."]`) is rejected with a hint to
    /// use `#[pyclass(name = "...")]`. The Python name is whatever
    /// `ItemMeta::simple_name` returns for the parsed arguments.
    fn extract_class(ident: &Ident, meta: Meta) -> Result<ModuleItem, Diagnostic> {
        let nesteds = meta_to_vec(meta).map_err(|meta| {
            err_span!(
                meta,
                "#[pyclass = \"...\"] cannot be a name/value, you probably meant \
                 #[pyclass(name = \"...\")]",
            )
        })?;
        let item_meta =
            ItemMeta::from_nested_meta("pyclass", ident, &nesteds, ItemMeta::SIMPLE_NAMES)?;
        Ok(ModuleItem::Class {
            item_ident: ident.clone(),
            py_name: item_meta.simple_name()?,
        })
    }

    /// Scans the attributes of one item and records every `#[pyfunction]` /
    /// `#[pyclass]` found on it.
    ///
    /// Attributes that do not parse as a `Meta`, or whose path is not a single
    /// identifier, are ignored. On success every `#[pyfunction]` attribute is
    /// removed from `attrs`; `#[pyclass]` attributes are left in place
    /// (presumably so the `pyclass` macro can still process them -- confirm).
    ///
    /// Returns the first error hit while parsing an attribute or recording a
    /// duplicate; in that case `attrs` is left unmodified.
    fn extract_item_from_syn(
        &mut self,
        attrs: &mut Vec<Attribute>,
        ident: &Ident,
    ) -> Result<(), Diagnostic> {
        let mut attr_idxs = Vec::new();
        // Enumerate BEFORE filtering out unparsable attributes so that `i` is
        // the attribute's real position in `attrs`; the removal pass below
        // depends on it.
        for (i, meta) in attrs
            .iter()
            .enumerate()
            .filter_map(|(i, attr)| attr.parse_meta().ok().map(|meta| (i, meta)))
        {
            let meta_span = meta.span();
            let name = match meta.path().get_ident() {
                Some(name) => name,
                None => continue,
            };
            let item = match name.to_string().as_str() {
                "pyfunction" => {
                    attr_idxs.push(i);
                    Self::extract_function(ident, meta)?
                }
                "pyclass" => Self::extract_class(ident, meta)?,
                _ => continue,
            };
            self.add_item(item, meta_span)?;
        }

        // Strip the consumed `#[pyfunction]` attributes in a single pass.
        // `attr_idxs` is ascending and `Vec::retain` visits elements in their
        // original order, so only the head of `pending` needs to be checked.
        let mut position = 0;
        let mut pending = attr_idxs.as_slice();
        attrs.retain(|_| {
            let strip = pending.first() == Some(&position);
            if strip {
                pending = &pending[1..];
            }
            position += 1;
            !strip
        });
        Ok(())
    }
}
fn extract_module_items(mut items: Vec<ItemIdent>) -> Result<TokenStream2, Diagnostic> {
let mut diagnostics: Vec<Diagnostic> = Vec::new();
let mut module = Module::default();
for item in items.iter_mut() {
push_diag_result!(
diagnostics,
module.extract_item_from_syn(&mut item.attrs, item.ident),
);
}
let functions = module.items.into_iter().map(|item| match item {
ModuleItem::Function {
item_ident,
py_name,
} => {
let new_func = quote_spanned!(item_ident.span() => .new_function(#item_ident));
quote! {
vm.__module_set_attr(&module, #py_name, vm.ctx#new_func).unwrap();
}
}
ModuleItem::Class {
item_ident,
py_name,
} => {
let new_class = quote_spanned!(item_ident.span() => #item_ident::make_class(&vm.ctx));
quote! {
vm.__module_set_attr(&module, #py_name, #new_class).unwrap();
}
}
});
Diagnostic::from_vec(diagnostics)?;
Ok(quote! {
#(#functions)*
})
}
/// Expands the `#[pymodule]` attribute applied to `item`.
///
/// `item` must be a module declaration with an inline body; anything else is
/// reported as an error. The module name is resolved by `def_to_name` from the
/// module identifier and the attribute arguments. Functions, structs and enums
/// in the body are scanned for `#[pyfunction]` / `#[pyclass]`, and three items
/// are appended to the body:
///
/// * `const MODULE_NAME: &str` holding the resolved name,
/// * `extend_module(vm, module)`, which registers the collected items,
/// * `make_module(vm)`, which creates a new module and extends it.
///
/// Returns the token stream of the rewritten module.
pub fn impl_pymodule(attr: AttributeArgs, item: Item) -> Result<TokenStream2, Diagnostic> {
    let mut module = match item {
        Item::Mod(item_mod) => item_mod,
        other => bail_span!(other, "#[pymodule] can only be on a module declaration"),
    };
    let module_name = def_to_name(&module.ident, "pymodule", attr)?;
    let content = match module.content.as_mut() {
        Some((_, body)) => body,
        None => bail_span!(
            module,
            "#[pymodule] can only be on a module declaration with body"
        ),
    };

    // Only functions, structs and enums can carry the attributes of interest.
    let mut items = Vec::new();
    for entry in content.iter_mut() {
        let item_ident = match entry {
            Item::Fn(syn::ItemFn { attrs, sig, .. }) => ItemIdent {
                attrs,
                ident: &sig.ident,
            },
            Item::Struct(syn::ItemStruct { attrs, ident, .. }) => ItemIdent { attrs, ident },
            Item::Enum(syn::ItemEnum { attrs, ident, .. }) => ItemIdent { attrs, ident },
            _ => continue,
        };
        items.push(item_ident);
    }
    let extend_mod = extract_module_items(items)?;

    let name_const: Item = parse_quote! {
        const MODULE_NAME: &str = #module_name;
    };
    let extend_fn: Item = parse_quote! {
        pub(crate) fn extend_module(
            vm: &::rustpython_vm::vm::VirtualMachine,
            module: &::rustpython_vm::pyobject::PyObjectRef,
        ) {
            #extend_mod
        }
    };
    let make_fn: Item = parse_quote! {
        #[allow(dead_code)]
        pub(crate) fn make_module(
            vm: &::rustpython_vm::vm::VirtualMachine
        ) -> ::rustpython_vm::pyobject::PyObjectRef {
            let module = vm.new_module(MODULE_NAME, vm.ctx.new_dict());
            extend_module(vm, &module);
            module
        }
    };
    content.push(name_const);
    content.push(extend_fn);
    content.push(make_fn);

    Ok(module.into_token_stream())
}