use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{
Attribute, ImplItem, ItemEnum, ItemImpl, ItemStruct, LitBool, LitStr, Token,
parse::{Parse, ParseStream},
parse_macro_input,
};
/// Arguments accepted by the `py_compat_*` attribute macros in this file.
struct MacroArgs {
/// Cargo feature name used in the generated `cfg(feature = ...)` gates.
feature: String,
/// Feature name gating the generated `pyo3_stub_gen` attribute; `None` means
/// no stub attribute is emitted.
stub_gen: Option<String>,
/// Raw tokens placed inside the parentheses of the generated
/// `#[::pyo3::pyclass(...)]` attribute, when present.
pyclass_args: Option<TokenStream2>,
}
impl Default for MacroArgs {
fn default() -> Self {
Self {
feature: "python".to_string(),
stub_gen: None,
pyclass_args: None,
}
}
}
/// Parses the comma-separated argument list of a `py_compat_*` attribute.
///
/// Recognised arguments:
/// * `feature = "name"` — sets the gating feature name.
/// * `stub_gen = true | false | "name"` — `true` enables stub generation on
///   the `"python"` feature, `false` disables it, a string names the feature.
/// * `pyclass_args(...)` — tokens forwarded verbatim to `pyclass`.
///
/// If an argument is repeated, the last occurrence wins.
impl Parse for MacroArgs {
    /// # Errors
    ///
    /// Returns a `syn::Error` for an unknown argument name, a malformed value,
    /// or two arguments that are not separated by a comma.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let mut args = MacroArgs::default();
        while !input.is_empty() {
            let ident: syn::Ident = input.parse()?;
            match ident.to_string().as_str() {
                "feature" => {
                    input.parse::<Token![=]>()?;
                    args.feature = input.parse::<LitStr>()?.value();
                }
                "stub_gen" => {
                    input.parse::<Token![=]>()?;
                    if input.peek(LitBool) {
                        let b: LitBool = input.parse()?;
                        // `true` maps to the literal "python" feature,
                        // `false` to "no stub generation".
                        args.stub_gen = b.value().then(|| "python".to_string());
                    } else {
                        args.stub_gen = Some(input.parse::<LitStr>()?.value());
                    }
                }
                "pyclass_args" => {
                    let inner;
                    syn::parenthesized!(inner in input);
                    args.pyclass_args = Some(inner.parse::<TokenStream2>()?);
                }
                other => {
                    return Err(syn::Error::new(
                        ident.span(),
                        format!(
                            "unknown argument `{other}`; \
                             expected `feature`, `stub_gen`, or `pyclass_args`"
                        ),
                    ));
                }
            }
            // A trailing comma is optional, but two arguments must be
            // separated by one. Report a missing separator instead of
            // silently discarding the parse error.
            if input.is_empty() {
                break;
            }
            input.parse::<Token![,]>()?;
        }
        Ok(args)
    }
}
/// Returns `true` when the first path segment of `attr` is one of the PyO3
/// attribute names handled by this file (`pyo3`, `pyclass`, `pymethods`,
/// `pyfunction`, `pymodule`).
fn is_pyo3_related(attr: &Attribute) -> bool {
    const PYO3_NAMES: [&str; 5] = ["pyo3", "pyclass", "pymethods", "pyfunction", "pymodule"];
    match attr.path().segments.first() {
        Some(leading) => {
            let name = leading.ident.to_string();
            PYO3_NAMES.contains(&name.as_str())
        }
        // A path with no segments cannot name a PyO3 attribute.
        None => false,
    }
}
/// Returns `true` for the marker attributes `#[py_only]` and `#[py_attrs]`,
/// which are consumed by `py_compat_methods` and never emitted.
fn is_sentinel(attr: &Attribute) -> bool {
    let path = attr.path();
    if path.is_ident("py_only") {
        return true;
    }
    path.is_ident("py_attrs")
}
/// Returns `true` when `attr` is a bare `#[gen_stub ...]` attribute.
fn is_gen_stub(attr: &Attribute) -> bool {
    let path = attr.path();
    path.is_ident("gen_stub")
}
/// Removes every `gen_stub` attribute from an impl item.
///
/// Only `Fn`, `Const`, `Type` and `Macro` items are touched; any other item
/// kind is left unchanged.
fn strip_gen_stub_from_item(item: &mut ImplItem) {
    let attrs = match item {
        ImplItem::Fn(method) => &mut method.attrs,
        ImplItem::Const(constant) => &mut constant.attrs,
        ImplItem::Type(alias) => &mut alias.attrs,
        ImplItem::Macro(invocation) => &mut invocation.attrs,
        _ => return,
    };
    attrs.retain(|attr| !is_gen_stub(attr));
}
/// Removes every `gen_stub` attribute from each field in `fields`.
///
/// Unit field sets contain no fields, so they are a no-op.
fn strip_gen_stub_from_fields(fields: &mut syn::Fields) {
    // `Fields::iter_mut` yields nothing for `Fields::Unit`.
    fields
        .iter_mut()
        .for_each(|member| member.attrs.retain(|attr| !is_gen_stub(attr)));
}
/// Removes `gen_stub` attributes from every enum variant and from the fields
/// of every variant.
fn strip_gen_stub_from_variants(
    variants: &mut syn::punctuated::Punctuated<syn::Variant, Token![,]>,
) {
    variants.iter_mut().for_each(|entry| {
        entry.attrs.retain(|attr| !is_gen_stub(attr));
        strip_gen_stub_from_fields(&mut entry.fields);
    });
}
/// Borrows the attribute list of an impl item.
///
/// Returns an empty slice for item kinds other than `Fn`, `Const`, `Type`
/// and `Macro`.
fn impl_item_attrs(item: &ImplItem) -> &[Attribute] {
    match item {
        ImplItem::Fn(method) => method.attrs.as_slice(),
        ImplItem::Const(constant) => constant.attrs.as_slice(),
        ImplItem::Type(alias) => alias.attrs.as_slice(),
        ImplItem::Macro(invocation) => invocation.attrs.as_slice(),
        _ => &[],
    }
}
/// Drops all attributes (doc comments included) from an impl item.
///
/// Item kinds other than `Fn`, `Const`, `Type` and `Macro` are left as-is.
fn clear_impl_item_attrs(item: &mut ImplItem) {
    let attrs = match item {
        ImplItem::Fn(method) => &mut method.attrs,
        ImplItem::Const(constant) => &mut constant.attrs,
        ImplItem::Type(alias) => &mut alias.attrs,
        ImplItem::Macro(invocation) => &mut invocation.attrs,
        _ => return,
    };
    attrs.clear();
}
/// Removes the `#[py_only]` / `#[py_attrs]` marker attributes from an impl
/// item so they never reach the generated code.
fn strip_sentinels(item: &mut ImplItem) {
    let attrs = match item {
        ImplItem::Fn(method) => &mut method.attrs,
        ImplItem::Const(constant) => &mut constant.attrs,
        ImplItem::Type(alias) => &mut alias.attrs,
        ImplItem::Macro(invocation) => &mut invocation.attrs,
        _ => return,
    };
    attrs.retain(|attr| !is_sentinel(attr));
}
/// Removes every PyO3-related attribute from each field in `fields`.
///
/// Unit field sets contain no fields, so they are a no-op.
fn strip_pyo3_from_fields(fields: &mut syn::Fields) {
    // `Fields::iter_mut` yields nothing for `Fields::Unit`.
    fields
        .iter_mut()
        .for_each(|member| member.attrs.retain(|attr| !is_pyo3_related(attr)));
}
/// Removes PyO3-related attributes from every enum variant and from the
/// fields of every variant.
fn strip_pyo3_from_variants(variants: &mut syn::punctuated::Punctuated<syn::Variant, Token![,]>) {
    variants.iter_mut().for_each(|entry| {
        entry.attrs.retain(|attr| !is_pyo3_related(attr));
        strip_pyo3_from_fields(&mut entry.fields);
    });
}
/// Returns `true` when every variant of the enum is a unit variant.
///
/// An enum with no variants counts as simple.
fn is_simple_enum(item: &ItemEnum) -> bool {
    for variant in &item.variants {
        if !matches!(variant.fields, syn::Fields::Unit) {
            return false;
        }
    }
    true
}
/// Selects which `pyo3_stub_gen::derive` attribute `stub_attr` emits.
enum StubKind {
/// Emits `gen_stub_pyclass`.
Struct,
/// Emits `gen_stub_pyclass_enum` (enum whose variants are all unit variants).
SimpleEnum,
/// Emits `gen_stub_pyclass_complex_enum` (enum with at least one non-unit variant).
ComplexEnum,
/// Emits `gen_stub_pymethods`.
Methods,
/// Emits `gen_stub_pyfunction`.
Function,
}
fn stub_attr(stub_gen: &Option<String>, kind: StubKind) -> TokenStream2 {
let Some(sg) = stub_gen else {
return quote! {};
};
match kind {
StubKind::Struct => quote! {
#[cfg_attr(feature = #sg, ::pyo3_stub_gen::derive::gen_stub_pyclass)]
},
StubKind::SimpleEnum => quote! {
#[cfg_attr(feature = #sg, ::pyo3_stub_gen::derive::gen_stub_pyclass_enum)]
},
StubKind::ComplexEnum => quote! {
#[cfg_attr(feature = #sg, ::pyo3_stub_gen::derive::gen_stub_pyclass_complex_enum)]
},
StubKind::Methods => quote! {
#[cfg_attr(feature = #sg, ::pyo3_stub_gen::derive::gen_stub_pymethods)]
},
StubKind::Function => quote! {
#[cfg_attr(feature = #sg, ::pyo3_stub_gen::derive::gen_stub_pyfunction)]
},
}
}
#[proc_macro_attribute]
pub fn py_compat_struct(attr: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as MacroArgs);
let feature = &args.feature;
let input_struct = parse_macro_input!(input as ItemStruct);
let mut py_struct = input_struct.clone();
let mut plain_struct = input_struct;
plain_struct
.attrs
.retain(|a| !is_pyo3_related(a) && !is_gen_stub(a));
strip_pyo3_from_fields(&mut plain_struct.fields);
strip_gen_stub_from_fields(&mut plain_struct.fields);
if args.stub_gen.is_none() {
py_struct.attrs.retain(|a| !is_gen_stub(a));
strip_gen_stub_from_fields(&mut py_struct.fields);
}
let stub = stub_attr(&args.stub_gen, StubKind::Struct);
let pyclass_inner = args
.pyclass_args
.as_ref()
.map_or(quote! {}, |a| quote! { (#a) });
quote! {
#stub
#[cfg(feature = #feature)]
#[::pyo3::pyclass #pyclass_inner]
#py_struct
#[cfg(not(feature = #feature))]
#plain_struct
}
.into()
}
#[proc_macro_attribute]
pub fn py_compat_enum(attr: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as MacroArgs);
let feature = &args.feature;
let input_enum = parse_macro_input!(input as ItemEnum);
let stub_kind = if is_simple_enum(&input_enum) {
StubKind::SimpleEnum
} else {
StubKind::ComplexEnum
};
let mut py_enum = input_enum.clone();
let mut plain_enum = input_enum;
plain_enum
.attrs
.retain(|a| !is_pyo3_related(a) && !is_gen_stub(a));
strip_pyo3_from_variants(&mut plain_enum.variants);
strip_gen_stub_from_variants(&mut plain_enum.variants);
if args.stub_gen.is_none() {
py_enum.attrs.retain(|a| !is_gen_stub(a));
strip_gen_stub_from_variants(&mut py_enum.variants);
}
let stub = stub_attr(&args.stub_gen, stub_kind);
let pyclass_inner = args
.pyclass_args
.as_ref()
.map_or(quote! {}, |a| quote! { (#a) });
quote! {
#stub
#[cfg(feature = #feature)]
#[::pyo3::pyclass #pyclass_inner]
#py_enum
#[cfg(not(feature = #feature))]
#plain_enum
}
.into()
}
#[proc_macro_attribute]
pub fn py_compat_methods(attr: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as MacroArgs);
let feature = &args.feature;
let input_impl = parse_macro_input!(input as ItemImpl);
let self_ty = &input_impl.self_ty;
let (impl_generics, ty_generics, where_clause) = input_impl.generics.split_for_impl();
let pass_through_attrs: Vec<_> = input_impl
.attrs
.iter()
.filter(|a| !is_pyo3_related(a))
.collect();
let stub_gen_disabled = args.stub_gen.is_none();
let mut py_items = Vec::<TokenStream2>::new();
let mut plain_items = Vec::<TokenStream2>::new();
for item in &input_impl.items {
let attrs = impl_item_attrs(item);
let is_py_only = attrs.iter().any(|a| a.path().is_ident("py_only"));
let is_py_attrs = attrs.iter().any(|a| a.path().is_ident("py_attrs"));
if is_py_only && is_py_attrs {
return syn::Error::new_spanned(
quote! { #item },
"`#[py_only]` and `#[py_attrs]` cannot both appear on the same item",
)
.to_compile_error()
.into();
}
let mut clean = item.clone();
strip_sentinels(&mut clean);
if is_py_only {
if stub_gen_disabled {
strip_gen_stub_from_item(&mut clean);
}
py_items.push(quote! { #clean });
} else if is_py_attrs {
let mut stripped = clean.clone();
clear_impl_item_attrs(&mut stripped);
if stub_gen_disabled {
strip_gen_stub_from_item(&mut clean);
}
py_items.push(quote! { #clean });
plain_items.push(quote! { #stripped });
} else {
let mut py_clean = clean.clone();
let mut plain_clean = clean;
if stub_gen_disabled {
strip_gen_stub_from_item(&mut py_clean);
}
strip_gen_stub_from_item(&mut plain_clean);
py_items.push(quote! { #py_clean });
plain_items.push(quote! { #plain_clean });
}
}
let stub = stub_attr(&args.stub_gen, StubKind::Methods);
quote! {
#stub
#[cfg(feature = #feature)]
#[::pyo3::pymethods]
#(#pass_through_attrs)*
impl #impl_generics #self_ty #ty_generics #where_clause {
#(#py_items)*
}
#[cfg(not(feature = #feature))]
#(#pass_through_attrs)*
impl #impl_generics #self_ty #ty_generics #where_clause {
#(#plain_items)*
}
}
.into()
}
#[proc_macro_attribute]
pub fn py_compat_fn(attr: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as MacroArgs);
let feature = &args.feature;
let input_fn = parse_macro_input!(input as syn::ItemFn);
let py_fn = input_fn.clone();
let mut plain_fn = input_fn;
plain_fn.attrs.retain(|a| !is_pyo3_related(a));
let stub = stub_attr(&args.stub_gen, StubKind::Function);
quote! {
#stub
#[cfg(feature = #feature)]
#[::pyo3::pyfunction]
#py_fn
#[cfg(not(feature = #feature))]
#plain_fn
}
.into()
}