use std::fmt;
use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{format_ident, quote};
use syn::{
ext::IdentExt,
parse::{Parse, ParseStream},
parse_macro_input, parse_quote,
spanned::Spanned,
token,
};
#[proc_macro_attribute]
pub fn sealed(args: TokenStream, input: TokenStream) -> TokenStream {
match parse_macro_input!(input) {
syn::Item::Impl(item_impl) => parse_sealed_impl(&item_impl),
syn::Item::Trait(item_trait) => {
Ok(parse_sealed_trait(item_trait, parse_macro_input!(args)))
}
_ => Err(syn::Error::new(Span::call_site(), "expected impl or trait")),
}
.unwrap_or_else(|e| e.to_compile_error())
.into()
}
/// Expands `#[sealed]` applied to a trait.
///
/// Emits a module (named by [`seal_name`] from the unraw'd trait name and given the
/// visibility from `args`) containing a `Sealed` trait. The annotated trait is then
/// re-emitted with `<seal module>::Sealed<...>` appended to its supertraits,
/// instantiated with the trait's own generic arguments.
///
/// * Default mode: `Sealed` copies the trait's generic parameters, supertraits and
///   where clause. The module glob-imports its parent (`use super::*`) so those resolve.
/// * Erased mode (`erase` argument): `Sealed` keeps only the generic parameter list,
///   with every type parameter reduced to `T: ?Sized` (its bounds and default dropped).
///   Supertraits and the where clause are omitted.
fn parse_sealed_trait(mut item_trait: syn::ItemTrait, args: TraitArguments) -> TokenStream2 {
    let trait_ident = item_trait.ident.unraw();
    let seal = seal_name(&trait_ident);
    let vis = &args.visibility;
    let trait_generics = &item_trait.generics;
    let trait_supertraits = &item_trait.supertraits;
    let (_, ty_generics, where_clause) = trait_generics.split_for_impl();

    let mod_code = if args.erased {
        // Keep the parameters in their declared order. `ty_generics` (used below to
        // instantiate `Sealed`) lists the trait's arguments in that order, and Rust allows
        // type and const parameters to be interleaved. Regrouping them (types first, then
        // consts) would pass arguments to the wrong parameters.
        let params = trait_generics.params.iter().map(|param| match param {
            syn::GenericParam::Type(syn::TypeParam { ident, .. }) => quote!( #ident : ?Sized ),
            other => quote!( #other ),
        });
        quote! {
            pub trait Sealed< #(#params ,)* > {}
        }
    } else {
        quote! {
            use super::*;
            pub trait Sealed #trait_generics : #trait_supertraits #where_clause {}
        }
    };

    // Make the seal a supertrait, so only types implementing `Sealed` can implement
    // the annotated trait.
    item_trait
        .supertraits
        .push(parse_quote!( #seal::Sealed #ty_generics ));

    quote! {
        #[automatically_derived]
        #vis mod #seal {
            #mod_code
        }
        #item_trait
    }
}
fn parse_sealed_impl(item_impl: &syn::ItemImpl) -> syn::Result<TokenStream2> {
let impl_trait = item_impl
.trait_
.as_ref()
.ok_or_else(|| syn::Error::new_spanned(item_impl, "missing implementation trait"))?;
let mut sealed_path = impl_trait.1.segments.clone();
let syn::PathSegment { ident, arguments } = sealed_path.pop().unwrap().into_value();
let seal = seal_name(ident.unraw());
sealed_path.push(parse_quote!( #seal ));
sealed_path.push(parse_quote!(Sealed));
let self_type = &item_impl.self_ty;
let (trait_generics, _, where_clauses) = item_impl.generics.split_for_impl();
Ok(quote! {
#[automatically_derived]
impl #trait_generics #sealed_path #arguments for #self_type #where_clauses {}
#item_impl
})
}
/// Converts an `UpperCamelCase` identifier into `snake_case`.
///
/// Every uppercase character except the first one is preceded by `_`, and all
/// uppercase characters are lowercased. Every other character is copied unchanged.
/// Consecutive capitals are split individually (`HTTP` -> `h_t_t_p`).
///
/// Works on `char`s rather than bytes, so non-ASCII identifiers (which Rust allows)
/// are preserved. Casting individual UTF-8 bytes to `char` would yield characters that
/// are not valid in an identifier.
fn to_snake_case(s: &'_ str) -> String {
    let mut ret = String::with_capacity(s.len());
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i != 0 {
                ret.push('_');
            }
            // `to_lowercase` may yield several chars for some non-ASCII letters.
            ret.extend(c.to_lowercase());
        } else {
            ret.push(c);
        }
    }
    ret
}
fn seal_name<D: fmt::Display>(seal: D) -> syn::Ident {
format_ident!("__seal_{}", to_snake_case(&seal.to_string()))
}
/// Arguments accepted by `#[sealed(...)]` when it is applied to a trait.
struct TraitArguments {
    /// Set by the `erase` argument. The generated `Sealed` trait then keeps only the
    /// generic parameter list (type parameters relaxed to `?Sized`), without the
    /// trait's supertraits or where clause.
    erased: bool,
    /// Visibility given to the generated seal module. Private (`Inherited`) unless a
    /// restricted `pub(...)` argument was supplied.
    visibility: syn::Visibility,
}
impl Default for TraitArguments {
    /// No erasure and a private seal module.
    fn default() -> Self {
        let erased = false;
        let visibility = syn::Visibility::Inherited;
        TraitArguments { erased, visibility }
    }
}
impl Parse for TraitArguments {
    /// Parses the comma-separated argument list of `#[sealed(...)]` on a trait.
    ///
    /// Accepted arguments (a trailing comma is allowed; repeating an argument
    /// overwrites the earlier value):
    /// * `erase`: sets [`TraitArguments::erased`].
    /// * a visibility starting with `pub`, e.g. `pub(crate)`: sets
    ///   [`TraitArguments::visibility`]. A plain `pub` is rejected because it would
    ///   let other crates reach the seal.
    ///
    /// # Errors
    /// Returns an error for an unknown argument, for a plain `pub` visibility, or when
    /// two arguments are not separated by `,`.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let mut out = Self::default();
        while !input.is_empty() {
            // Peek the argument name on a fork. `pub` is a keyword, hence `parse_any`.
            // The `pub` case must still find it in `input` for `Visibility::parse`.
            let ident = syn::Ident::parse_any(&input.fork())?;
            match ident.to_string().as_str() {
                "erase" => {
                    syn::Ident::parse_any(input)?;
                    out.erased = true;
                }
                "pub" => {
                    out.visibility = input.parse()?;
                    if matches!(out.visibility, syn::Visibility::Public(_)) {
                        return Err(syn::Error::new(
                            out.visibility.span(),
                            "`pub` visibility breaks the seal as allows to use \
                             it outside its crate.\n\
                             Consider tightening the visibility (e.g. \
                             `pub(crate)`) if you actually need sealing.",
                        ));
                    }
                }
                unknown => {
                    return Err(syn::Error::new(
                        ident.span(),
                        format!("unknown `{}` attribute argument", unknown),
                    ));
                }
            }
            // Arguments are comma-separated. When input remains, it must start with `,`.
            // Otherwise syn reports "expected `,`" at the offending token.
            if input.is_empty() {
                break;
            }
            input.parse::<token::Comma>()?;
        }
        Ok(out)
    }
}