use proc_macro::TokenStream;
use quote::quote;
use syn::{
DeriveInput, GenericParam, LifetimeParam, Meta, Token, WhereClause, parse_macro_input,
punctuated::Punctuated,
};
#[proc_macro_attribute]
pub fn context_deserialize(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr with Punctuated::<Meta, Token![,]>::parse_terminated);
let input = parse_macro_input!(item as DeriveInput);
let ident = &input.ident;
let mut ctx_types = Vec::new();
let mut explicit_where: Option<WhereClause> = None;
for meta in args {
match meta {
Meta::Path(p) => {
ctx_types.push(p);
}
Meta::NameValue(nv) if nv.path.is_ident("bound") => {
if let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(lit_str),
..
}) = &nv.value
{
let where_string = format!("where {}", lit_str.value());
match syn::parse_str::<WhereClause>(&where_string) {
Ok(where_clause) => {
explicit_where = Some(where_clause);
}
Err(err) => {
return syn::Error::new_spanned(
lit_str,
format!("Invalid where clause '{}': {}", lit_str.value(), err),
)
.to_compile_error()
.into();
}
}
} else {
return syn::Error::new_spanned(
&nv,
"Expected a string literal for `bound` value",
)
.to_compile_error()
.into();
}
}
_ => {
return syn::Error::new_spanned(
&meta,
"Expected paths or `bound = \"...\"` in #[context_deserialize(...)]",
)
.to_compile_error()
.into();
}
}
}
if ctx_types.is_empty() {
return quote! {
compile_error!("Usage: #[context_deserialize(Type1, Type2, ..., bound = \"...\")]");
}
.into();
}
let original_generics = input.generics.clone();
let mut impl_generics = input.generics.clone();
for param in impl_generics.params.iter_mut() {
if let GenericParam::Type(ty) = param {
ty.eq_token = None;
ty.default = None;
}
}
let has_de = impl_generics
.lifetimes()
.any(|LifetimeParam { lifetime, .. }| lifetime.ident == "de");
if !has_de {
impl_generics.params.insert(0, syn::parse_quote! { 'de });
}
let (_, ty_generics, _) = original_generics.split_for_impl();
let (impl_gens, _, _) = impl_generics.split_for_impl();
let mut impls = quote! {};
for ctx in ctx_types {
impls.extend(quote! {
impl #impl_gens context_deserialize::ContextDeserialize<'de, #ctx>
for #ident #ty_generics
#explicit_where
{
fn context_deserialize<D>(
deserializer: D,
_context: #ctx,
) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
<Self as serde::Deserialize>::deserialize(deserializer)
}
}
});
}
quote! {
#input
#impls
}
.into()
}