use std::cmp::Ordering;
use derive_syn_parse::Parse;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{format_ident, quote};
use syn::{
parse,
parse::ParseStream,
token::{Brace, Where},
Block, ConstParam, Expr, GenericParam, Generics, Item, ItemEnum, ItemFn, ItemImpl, ItemStruct,
ItemType, ItemUnion, Token, TraitItem, TraitItemMethod, TraitItemType, TypeParam, Fields,
};
/// Panic message used when the annotated item is of a kind the macro does not handle.
const INVALID_ITEM: &str = "guarded items need to support `where` clauses";
/// Panic message used when the attribute argument cannot be parsed as a `Guard`.
const INVALID_GUARD: &str = "guard must be an expression or polymorphic block";
/// Panic message emitted by the generated guard `const fn` when its condition is false.
const GUARD_FAILED: &str = "guard evaluated to false";
/// An annotated item split into the pieces that `guard` stitches back together
/// around its extra `where` predicate.
struct GuardItem {
/// Name from which the guard function's identifier (`_{ident}_guard`) is derived.
ident: Ident,
/// Leading tokens of the item; `guard` emits its `where` extension right after these.
decl: proc_macro2::TokenStream,
/// The item's generics: their params are re-declared on the guard function and
/// their `where` clause decides whether `guard` opens or extends a clause.
generics: Generics,
/// Remaining tokens of the item, emitted after the appended `where` predicate.
cont: proc_macro2::TokenStream,
}
/// A guard written as a block, optionally preceded by its own generic parameters
/// (for example `<const M: usize> { ... }`).
#[derive(Parse)]
struct PolyBlock {
/// Guard-local generic parameters; only parsed when the input starts with `<`.
#[peek(Token![<])]
generics: Option<Generics>,
/// The guard body; the value of the block is the guard's condition.
block: Block,
}
/// Parsed argument of the `#[guard(...)]` attribute.
///
/// The variant is chosen by peeking at the first token: `<` or `{` selects
/// `PolyBlock`; anything else is parsed as an `Expr`.
/// NOTE(review): an expression that itself starts with `<` (a qualified path such
/// as `<T as Trait>::CONST`) is routed to `PolyBlock` by this rule and presumably
/// fails to parse there — confirm whether that limitation is intended.
#[derive(Parse)]
enum Guard {
#[peek_with(|input: ParseStream<'_>| input.peek(Token![<]) || input.peek(Brace), name = "PolyBlock")]
PolyBlock(PolyBlock),
#[peek_with(|input: ParseStream<'_>| !(input.peek(Token![<]) || input.peek(Brace)), name = "Expr")]
Expr(Expr),
}
#[proc_macro_attribute]
pub fn guard(attr: TokenStream, stream: TokenStream) -> TokenStream {
let GuardItem {
ident,
decl,
generics,
cont,
} = GuardItem::from(stream);
let guard_ident = format_ident!("_{ident}_guard");
let where_ext = where_ext(&generics);
let (guard, generics, param_idents) = match Guard::from(attr) {
Guard::PolyBlock(PolyBlock {
generics: guard_generics,
block,
}) => {
let params = if let Some(guard_generics) = guard_generics {
merge_generic_params(generics, guard_generics)
} else {
generics.params.into_iter().collect::<Vec<GenericParam>>()
};
let generics = quote! {< #(#params),* >};
(quote!(#block), generics, param_idents(¶ms))
}
Guard::Expr(expr) => {
let params = generics.params.into_iter().collect::<Vec<GenericParam>>();
let generics = quote! {< #(#params),* >};
(quote!((#expr)), generics, param_idents(¶ms))
}
};
let tokens = quote! {
#decl #where_ext const_guards::Guard<{
#[allow(non_snake_case, private_in_public)] const fn #guard_ident #generics() -> bool {
if !#guard {
panic!(#GUARD_FAILED)
}
true
}
#guard_ident::<#param_idents>()
}>: const_guards::Protect #cont
};
TokenStream::from(tokens)
}
fn param_idents(params: &[GenericParam]) -> proc_macro2::TokenStream {
let idents = params
.iter()
.filter_map(param_ident)
.collect::<Vec<&Ident>>();
quote! {#(#idents),*}
}
/// Returns the tokens that must precede the guard's `where` predicate.
///
/// The result depends on the item's existing `where` clause:
/// - No `where` clause: returns the `where` keyword, which opens a clause.
/// - A clause that is empty or already ends in `,`: returns `None`, so the
///   predicate is appended directly.
/// - Otherwise: returns `,`, which separates the predicate from the last one.
fn where_ext(generics: &Generics) -> Option<proc_macro2::TokenStream> {
    match &generics.where_clause {
        None => {
            let kw_where = Where(Span::call_site());
            Some(quote!(#kw_where))
        }
        // An empty clause (`where` with no predicates) must not get a leading `,`.
        Some(clause) if clause.predicates.is_empty() || clause.predicates.trailing_punct() => None,
        Some(_) => Some(quote! {,}),
    }
}
/// Merges the item's generic parameters (`left`) with the guard block's own
/// (`right`) for use on the generated guard function.
///
/// Only type and const parameters are kept; lifetimes are dropped.
/// NOTE(review): a kept parameter whose bounds mention a dropped lifetime
/// (e.g. `T: 'a`) will not compile on the generated function. Confirm whether
/// lifetimes should be carried over, as the plain-expression guard path does.
///
/// The result is sorted by identifier and deduplicated. Because `sort_by` is
/// stable, the `left` (item) declaration wins when both sides declare the same
/// name.
fn merge_generic_params(left: Generics, right: Generics) -> Vec<GenericParam> {
    let mut params = left
        .params
        .into_iter()
        .chain(right.params)
        .filter(|param| matches!(param, GenericParam::Type(_) | GenericParam::Const(_)))
        .collect::<Vec<GenericParam>>();
    params.sort_by(compare_params);
    // `dedup_by` keeps the first of each run of equal elements.
    params.dedup_by(|later, earlier| compare_params(later, earlier).is_eq());
    params
}
/// Orders two generic parameters by their identifiers.
///
/// Both callers (in `merge_generic_params`) pass only type and const parameters,
/// so a parameter without an identifier is unreachable here.
fn compare_params(left: &GenericParam, right: &GenericParam) -> Ordering {
    let names = param_ident(left).zip(param_ident(right));
    match names {
        Some((lhs, rhs)) => lhs.cmp(rhs),
        None => unreachable!(),
    }
}
fn param_ident(param: &GenericParam) -> Option<&Ident> {
match param {
syn::GenericParam::Type(TypeParam { ident, .. }) => Some(ident),
syn::GenericParam::Const(ConstParam { ident, .. }) => Some(ident),
_ => None,
}
}
impl From<TokenStream> for GuardItem {
    /// Parses the annotated item and splits it into `GuardItem` parts.
    ///
    /// Anything syn returns as `Item::Verbatim` (for example a trait method
    /// without a body) is re-parsed as a `TraitItem`.
    ///
    /// # Panics
    /// Panics with syn's parse error when the tokens cannot be parsed, and with
    /// `INVALID_ITEM` for unsupported item kinds.
    fn from(stream: TokenStream) -> Self {
        let item = parse::<Item>(stream)
            .unwrap_or_else(|err| panic!("failed to parse guarded item: {err}"));
        match item {
            Item::Verbatim(tokens) => {
                let trait_item = parse::<TraitItem>(TokenStream::from(tokens))
                    .unwrap_or_else(|err| panic!("failed to parse guarded item: {err}"));
                GuardItem::from(trait_item)
            }
            item => GuardItem::from(item),
        }
    }
}
impl From<TokenStream> for Guard {
    /// Parses the arguments of `#[guard(...)]`.
    ///
    /// # Panics
    /// Panics with `INVALID_GUARD`, followed by the debug form of syn's error,
    /// when the arguments do not parse as a `Guard`.
    fn from(stream: TokenStream) -> Self {
        match parse::<Guard>(stream) {
            Ok(guard) => guard,
            Err(err) => panic!("{INVALID_GUARD}: {err:?}"),
        }
    }
}
impl From<Item> for GuardItem {
    /// Splits a top-level item into the pieces `guard` needs.
    ///
    /// `decl` ends right where the guard's `where` predicate is to be inserted,
    /// i.e. after the item's own `where` clause when it has one. `cont` holds the
    /// rest of the item (body, fields, `= Type;`, ...). Impl blocks have no name,
    /// so the fixed identifier `impl` is used to name their guard function.
    ///
    /// # Panics
    /// Panics with `INVALID_ITEM` for item kinds not handled here.
    fn from(item: Item) -> Self {
        let (decl, ident, generics, cont) = match item {
            Item::Enum(ItemEnum {
                attrs,
                vis,
                enum_token,
                ident,
                generics,
                variants,
                ..
            }) => {
                let clause = &generics.where_clause;
                (
                    quote! {#(#attrs)* #vis #enum_token #ident #generics #clause},
                    ident,
                    generics,
                    quote! {{ #variants }},
                )
            }
            // syn prints a function's `where` clause as part of its signature.
            Item::Fn(ItemFn {
                attrs,
                vis,
                sig,
                block,
            }) => (
                quote! {#(#attrs)* #vis #sig},
                sig.ident,
                sig.generics,
                quote! {#block},
            ),
            Item::Impl(ItemImpl {
                attrs,
                defaultness,
                unsafety,
                impl_token,
                generics,
                trait_,
                self_ty,
                items,
                ..
            }) => {
                let trait_ = trait_.map(|(bang, path, kw_for)| quote! { #bang #path #kw_for});
                let clause = &generics.where_clause;
                (
                    quote! {#(#attrs)* #defaultness #unsafety #impl_token #generics #trait_ #self_ty #clause},
                    Ident::new("impl", Span::call_site()),
                    generics,
                    quote! {{#(#items)*}},
                )
            }
            Item::Struct(ItemStruct {
                attrs,
                vis,
                struct_token,
                ident,
                generics,
                fields,
                semi_token,
            }) => {
                let clause = &generics.where_clause;
                // Braced structs take their `where` clause before the fields;
                // tuple and unit structs take it after the fields, before `;`.
                let (head, body) = if matches!(fields, Fields::Named(_)) {
                    (quote! {#clause}, Some(quote! {#fields}))
                } else {
                    (quote! {#fields #clause}, None)
                };
                (
                    quote! {#(#attrs)* #vis #struct_token #ident #generics #head},
                    ident,
                    generics,
                    quote! {#body #semi_token},
                )
            }
            Item::Type(ItemType {
                attrs,
                vis,
                type_token,
                ident,
                generics,
                eq_token,
                ty,
                semi_token,
            }) => {
                let clause = &generics.where_clause;
                (
                    quote! {#(#attrs)* #vis #type_token #ident #generics #clause},
                    ident,
                    generics,
                    quote! {#eq_token #ty #semi_token},
                )
            }
            Item::Union(ItemUnion {
                attrs,
                vis,
                union_token,
                ident,
                generics,
                fields,
            }) => {
                // The union's own `where` clause must be kept. Without it the item
                // loses its bounds, and `where_ext` emits `,` with no `where`.
                let clause = &generics.where_clause;
                (
                    quote! {#(#attrs)* #vis #union_token #ident #generics #clause},
                    ident,
                    generics,
                    quote! {#fields},
                )
            }
            _ => panic!("{INVALID_ITEM}"),
        };
        GuardItem {
            ident,
            decl,
            generics,
            cont,
        }
    }
}
impl From<TraitItem> for GuardItem {
    /// Splits a trait method or associated type into the pieces `guard` needs.
    ///
    /// As for top-level items, `decl` ends where the guard's `where` predicate
    /// goes: after the item's own `where` clause, if any.
    ///
    /// # Panics
    /// Panics with `INVALID_ITEM` for other trait item kinds.
    fn from(item: TraitItem) -> Self {
        let (decl, ident, generics, cont) = match item {
            // syn prints a method's `where` clause as part of its signature.
            TraitItem::Method(TraitItemMethod {
                attrs,
                sig,
                default,
                semi_token,
            }) => (
                quote! {#(#attrs)* #sig},
                sig.ident,
                sig.generics,
                quote! {#default #semi_token},
            ),
            TraitItem::Type(TraitItemType {
                attrs,
                type_token,
                ident,
                generics,
                colon_token,
                bounds,
                default,
                semi_token,
            }) => {
                // Associated types are `type Name<Generics>: Bounds where ... = Default;`.
                // The generics, bounds and existing `where` clause therefore all
                // precede the appended predicate; only the default and `;` follow it.
                let clause = &generics.where_clause;
                let default = default.map(|(eq, ty)| quote! {#eq #ty});
                (
                    quote! {#(#attrs)* #type_token #ident #generics #colon_token #bounds #clause},
                    ident,
                    generics,
                    quote! {#default #semi_token},
                )
            }
            _ => panic!("{INVALID_ITEM}"),
        };
        GuardItem {
            ident,
            decl,
            generics,
            cont,
        }
    }
}