use amplify::proc_attr::ParametrizedAttr;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{ToTokens, TokenStreamExt};
use syn::spanned::Spanned;
use syn::{
Data, DataEnum, DataStruct, DeriveInput, Error, Field, Fields, Ident,
ImplGenerics, Index, Result, TypeGenerics, WhereClause,
};
use crate::param::{EncodingDerive, TlvDerive, CRATE, REPR, USE_TLV};
pub fn encode_derive(
attr_name: &'static str,
crate_name: Ident,
trait_name: Ident,
encode_name: Ident,
serialize_name: Ident,
input: DeriveInput,
tlv_encoding: bool,
) -> Result<TokenStream2> {
let (impl_generics, ty_generics, where_clause) =
input.generics.split_for_impl();
let ident_name = &input.ident;
let global_param = ParametrizedAttr::with(attr_name, &input.attrs)?;
match input.data {
Data::Struct(data) => encode_struct_impl(
attr_name,
&crate_name,
&trait_name,
&encode_name,
&serialize_name,
data,
ident_name,
global_param,
impl_generics,
ty_generics,
where_clause,
tlv_encoding,
),
Data::Enum(data) => encode_enum_impl(
attr_name,
&crate_name,
&trait_name,
&encode_name,
&serialize_name,
data,
ident_name,
global_param,
impl_generics,
ty_generics,
where_clause,
),
Data::Union(_) => Err(Error::new_spanned(
&input,
format!("Deriving `{}` is not supported in unions", trait_name),
)),
}
}
#[allow(clippy::too_many_arguments)]
fn encode_struct_impl(
attr_name: &'static str,
crate_name: &Ident,
trait_name: &Ident,
encode_name: &Ident,
serialize_name: &Ident,
data: DataStruct,
ident_name: &Ident,
mut global_param: ParametrizedAttr,
impl_generics: ImplGenerics,
ty_generics: TypeGenerics,
where_clause: Option<&WhereClause>,
tlv_encoding: bool,
) -> Result<TokenStream2> {
let encoding = EncodingDerive::with(
&mut global_param,
crate_name,
true,
false,
false,
)?;
if !tlv_encoding && encoding.tlv.is_some() {
return Err(Error::new(
ident_name.span(),
format!("TLV extensions are not allowed in `{}`", attr_name),
));
}
let inner_impl = match data.fields {
Fields::Named(ref fields) => encode_fields_impl(
attr_name,
crate_name,
trait_name,
encode_name,
serialize_name,
&fields.named,
global_param,
false,
tlv_encoding,
)?,
Fields::Unnamed(ref fields) => encode_fields_impl(
attr_name,
crate_name,
trait_name,
encode_name,
serialize_name,
&fields.unnamed,
global_param,
false,
tlv_encoding,
)?,
Fields::Unit => quote! { Ok(()) },
};
let import = encoding.use_crate;
Ok(quote! {
impl #impl_generics #import::#trait_name for #ident_name #ty_generics #where_clause {
fn #encode_name(&self, e: &mut impl ::std::io::Write) -> ::core::result::Result<(), #import::Error> {
use #import::#trait_name;
let data = self;
#inner_impl
Ok(())
}
}
})
}
/// Generates the `trait_name` implementation for an enum.
///
/// Each variant that is not marked as skipped becomes one arm of a
/// `match self`. The arm first encodes the variant discriminant, cast to the
/// container's `repr` type, and then the variant's fields. The discriminant
/// expression is chosen in this order:
/// 1. an explicit `value` from the merged container + variant attributes;
/// 2. the variant's position in the declaration when `by_order` is set;
/// 3. `Self::Variant` itself (cast with `as`).
///
/// Variant fields are bound in the match pattern (named fields by name,
/// tuple fields as `_0`, `_1`, ...). They are then re-packed into a tuple
/// named `data`, which is why `encode_fields_impl` is called with
/// `is_enum = true` and addresses fields by index.
///
/// # Errors
/// Propagates attribute-validation errors from the container, variant and
/// field levels.
#[allow(clippy::too_many_arguments)]
fn encode_enum_impl(
    attr_name: &'static str,
    crate_name: &Ident,
    trait_name: &Ident,
    encode_name: &Ident,
    serialize_name: &Ident,
    data: DataEnum,
    ident_name: &Ident,
    mut global_param: ParametrizedAttr,
    impl_generics: ImplGenerics,
    ty_generics: TypeGenerics,
    where_clause: Option<&WhereClause>,
) -> Result<TokenStream2> {
    let enum_encoding =
        EncodingDerive::with(&mut global_param, crate_name, true, true, false)?;
    let repr = enum_encoding.repr;
    let mut arms = TokenStream2::new();

    for (order, variant) in data.variants.iter().enumerate() {
        // Validate the variant's own attributes in isolation first.
        let mut variant_param =
            ParametrizedAttr::with(attr_name, &variant.attrs)?;
        let _ = EncodingDerive::with(
            &mut variant_param,
            crate_name,
            false,
            true,
            false,
        )?;

        // Then evaluate the container attributes overlaid with the variant
        // ones, minus the container-only `repr` and `crate` arguments.
        let mut merged_param =
            global_param.clone().merged(variant_param.clone())?;
        merged_param.args.remove(REPR);
        merged_param.args.remove(CRATE);
        let variant_encoding = EncodingDerive::with(
            &mut merged_param,
            crate_name,
            false,
            true,
            false,
        )?;
        if variant_encoding.skip {
            continue;
        }

        // Pattern bindings: named fields keep their names, tuple fields get
        // synthetic `_<index>` identifiers.
        let bindings: Vec<TokenStream2> = variant
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| match &field.ident {
                Some(name) => name.to_token_stream(),
                None => Ident::new(&format!("_{}", i), Span::call_site())
                    .to_token_stream(),
            })
            .collect();

        let (fields_impl, pattern) = match variant.fields {
            Fields::Unit => (TokenStream2::new(), TokenStream2::new()),
            Fields::Named(ref fields) => {
                let encoded = encode_fields_impl(
                    attr_name,
                    crate_name,
                    trait_name,
                    encode_name,
                    serialize_name,
                    &fields.named,
                    variant_param,
                    true,
                    false,
                )?;
                (encoded, quote! { { #( #bindings ),* } })
            }
            Fields::Unnamed(ref fields) => {
                let encoded = encode_fields_impl(
                    attr_name,
                    crate_name,
                    trait_name,
                    encode_name,
                    serialize_name,
                    &fields.unnamed,
                    variant_param,
                    true,
                    false,
                )?;
                (encoded, quote! { ( #( #bindings ),* ) })
            }
        };

        // Re-pack the bindings into a tuple so the field code can use
        // `data.0`, `data.1`, ... uniformly.
        let bind_data = if bindings.is_empty() {
            TokenStream2::new()
        } else {
            quote! { let data = ( #( #bindings ),* , ); }
        };

        let ident = &variant.ident;
        let value = if let Some(val) = variant_encoding.value {
            val.to_token_stream()
        } else if variant_encoding.by_order {
            Index::from(order).to_token_stream()
        } else {
            quote! { Self::#ident }
        };

        arms.append_all(quote_spanned! { variant.span() =>
            #[allow(clippy::unnecessary_cast)]
            Self::#ident #pattern => {
                (#value as #repr).#encode_name(e)?;
                #bind_data
                #fields_impl
            }
        });
    }

    let import = enum_encoding.use_crate;
    Ok(quote! {
        impl #impl_generics #import::#trait_name for #ident_name #ty_generics #where_clause {
            #[inline]
            fn #encode_name(&self, e: &mut impl ::std::io::Write) -> ::core::result::Result<(), #import::Error> {
                use #import::#trait_name;
                match self {
                    #arms
                }
                Ok(())
            }
        }
    })
}
/// Generates the statements that encode a sequence of struct or variant
/// fields.
///
/// The generated code expects a binding named `data` and a writer named `e`
/// to be in scope. Named struct fields are addressed as `data.<ident>`. Tuple
/// struct fields, and every field when `is_enum` is set, are addressed as
/// `data.<index>`.
///
/// Fields that are not routed to TLV are encoded first, in declaration order.
/// When the parent attributes contain `USE_TLV`, the TLV fields (and the
/// optional TLV aggregator field) are then inserted into an
/// `internet2::tlv::Stream`, which is encoded last.
///
/// # Errors
/// Returns an error when TLV encoding is requested but `tlv_encoding` is
/// `false`, and propagates field attribute validation errors.
#[allow(clippy::too_many_arguments)]
fn encode_fields_impl<'a>(
    attr_name: &'static str,
    crate_name: &Ident,
    _trait_name: &Ident,
    encode_name: &Ident,
    serialize_name: &Ident,
    fields: impl IntoIterator<Item = &'a Field>,
    mut parent_param: ParametrizedAttr,
    is_enum: bool,
    tlv_encoding: bool,
) -> Result<TokenStream2> {
    let mut stream = TokenStream2::new();

    let use_tlv = parent_param.args.contains_key(USE_TLV);
    // Strip container-only arguments before the parent attributes are merged
    // into each field's attributes below.
    parent_param.args.remove(CRATE);
    parent_param.args.remove(USE_TLV);

    if !tlv_encoding && use_tlv {
        return Err(Error::new(
            Span::call_site(),
            format!("TLV extensions are not allowed in `{}`", attr_name),
        ));
    }

    let mut strict_fields = vec![];
    let mut tlv_fields = bmap! {};
    let mut tlv_aggregator = None;

    for (index, field) in fields.into_iter().enumerate() {
        // Validate the field's own attributes, then the parent + field merge.
        let mut local_param = ParametrizedAttr::with(attr_name, &field.attrs)?;
        let _ = EncodingDerive::with(
            &mut local_param,
            crate_name,
            false,
            is_enum,
            use_tlv,
        )?;
        let mut combined = parent_param.clone().merged(local_param)?;
        let encoding = EncodingDerive::with(
            &mut combined,
            crate_name,
            false,
            is_enum,
            use_tlv,
        )?;
        if encoding.skip {
            continue;
        }

        // Enum variants are re-packed into a tuple by the caller, so their
        // fields are always addressed by position.
        let index = Index::from(index).to_token_stream();
        let name = if is_enum {
            index
        } else {
            field
                .ident
                .as_ref()
                .map(Ident::to_token_stream)
                .unwrap_or(index)
        };

        // Sort the field into the strict list, the TLV map or the aggregator
        // slot according to its TLV attribute.
        encoding.tlv.unwrap_or(TlvDerive::None).process(
            field,
            name,
            &mut strict_fields,
            &mut tlv_fields,
            &mut tlv_aggregator,
        )?;
    }

    for name in strict_fields {
        stream.append_all(quote_spanned! { Span::call_site() =>
            data.#name.#encode_name(e)?;
        })
    }

    if use_tlv {
        stream.append_all(quote_spanned! { Span::call_site() =>
            let mut tlvs = internet2::tlv::Stream::default();
        });
        for (type_no, (name, optional)) in tlv_fields {
            if optional {
                stream.append_all(quote_spanned! { Span::call_site() =>
                    if let Some(val) = &data.#name {
                        tlvs.insert(#type_no.into(), val.#serialize_name()?);
                    }
                });
            } else {
                // Only emit non-empty collections. `next().is_some()` tests
                // for emptiness in O(1) instead of counting every element.
                stream.append_all(quote_spanned! { Span::call_site() =>
                    if data.#name.iter().next().is_some() {
                        tlvs.insert(#type_no.into(), data.#name.#serialize_name()?);
                    }
                });
            }
        }
        if let Some(name) = tlv_aggregator {
            // NOTE(review): unlike the fields above, aggregated values are
            // inserted as-is (by reference, without serialization); presumably
            // the aggregator already holds raw TLV payloads — confirm.
            stream.append_all(quote_spanned! { Span::call_site() =>
                for (type_no, val) in &data.#name {
                    tlvs.insert(*type_no, val);
                }
            });
        }
        stream.append_all(quote_spanned! { Span::call_site() =>
            tlvs.#encode_name(e)?;
        })
    }

    Ok(stream)
}