use syn::{Generics, Ident, visit::{Visit, self}, Type, TypePath, spanned::Spanned};
use std::iter;
/// Syntax visitor that records whether any identifier it visits is one of `idents`.
struct ContainIdents<'a> {
    /// Set to `true` once a visited identifier equals one of `idents`.
    result: bool,
    /// The identifiers to look for.
    idents: &'a [Ident],
}

impl<'a, 'ast> Visit<'ast> for ContainIdents<'a> {
    fn visit_ident(&mut self, i: &'ast Ident) {
        // Compare by value via `PartialEq`; no need to reference both operands.
        if self.idents.contains(i) {
            self.result = true;
        }
    }
}
/// Returns `true` if any identifier appearing anywhere inside `ty` is one of `idents`.
fn type_contain_idents(ty: &Type, idents: &[Ident]) -> bool {
    let mut finder = ContainIdents { idents, result: false };
    finder.visit_type(ty);
    finder.result
}
/// Syntax visitor that records whether any visited type path has `ident` as its
/// first path segment.
///
/// When a type path itself matches, the visitor does not descend into it. Otherwise
/// `visit::visit_type_path` continues into the path's nested syntax, so type paths
/// nested inside it are checked as well.
struct TypePathStartsWithIdent<'a> {
    /// Set to `true` once a matching type path has been seen.
    result: bool,
    /// The identifier that a type path's first segment is compared against.
    ident: &'a Ident,
}

impl<'a, 'ast> Visit<'ast> for TypePathStartsWithIdent<'a> {
    fn visit_type_path(&mut self, type_path: &'ast TypePath) {
        let head_matches = type_path
            .path
            .segments
            .first()
            .map(|pair| pair.into_value())
            .map_or(false, |head| head.ident == *self.ident);

        if head_matches {
            self.result = true;
        } else {
            visit::visit_type_path(self, type_path);
        }
    }
}
/// Returns `true` if `ty`, or any type path nested inside it, has `ident` as its
/// first path segment.
fn type_path_or_sub_starts_with_ident(ty: &TypePath, ident: &Ident) -> bool {
    let mut finder = TypePathStartsWithIdent { ident, result: false };
    finder.visit_type_path(ty);
    finder.result
}
/// Returns `true` if any type path found anywhere inside `ty` has `ident` as its
/// first path segment.
fn type_or_sub_type_path_starts_with_ident(ty: &Type, ident: &Ident) -> bool {
    let mut finder = TypePathStartsWithIdent { ident, result: false };
    finder.visit_type(ty);
    finder.result
}
/// Syntax visitor that collects the outermost type paths that neither start with
/// nor contain `ident`.
///
/// A type path that does start with or contain `ident` is not collected. Instead,
/// the visitor descends into it so that qualifying type paths nested inside it
/// are collected.
struct FindTypePathsNotStartOrContainIdent<'a> {
    /// Collected type paths, in visiting order.
    result: Vec<TypePath>,
    /// The identifier that disqualifies a type path.
    ident: &'a Ident,
}

impl<'a, 'ast> Visit<'ast> for FindTypePathsNotStartOrContainIdent<'a> {
    fn visit_type_path(&mut self, i: &'ast TypePath) {
        // `self.ident` is already a reference; pass it through without re-borrowing.
        if !type_path_or_sub_starts_with_ident(i, self.ident) {
            // The whole path is free of `ident`: keep it and don't look inside.
            self.result.push(i.clone());
            return;
        }
        visit::visit_type_path(self, i);
    }
}
/// Collects the outermost type paths inside `ty` that neither start with nor
/// contain `ident`, descending into the type paths that do.
fn find_type_paths_not_start_or_contain_ident(ty: &Type, ident: &Ident) -> Vec<TypePath> {
    let mut collector = FindTypePathsNotStartOrContainIdent { ident, result: Vec::new() };
    collector.visit_type(ty);
    collector.result
}
/// Adds the trait bounds needed by the derived impl to the where clause of `generics`.
///
/// Nothing is added when `generics` has no type parameters. Otherwise only field
/// types that mention at least one type parameter are bounded, and fields of enum
/// variants rejected by `variant_not_skipped` are ignored:
///
/// * Fields accepted by `needs_codec_bound` get `#ty : #codec_bound`. If a field type
///   starts with or contains `input_ident`, that type is not bounded directly. Instead,
///   its nested type paths that don't mention `input_ident`, but do mention a type
///   parameter, are bounded.
/// * Fields accepted by `needs_has_compact_bound` get
///   `#ty : _parity_scale_codec::HasCompact`.
/// * If `codec_skip_bound` is `Some`, fields accepted by `needs_default_bound` get
///   `#ty : #codec_skip_bound`.
///
/// A where clause is only created on `generics` if at least one predicate is added.
///
/// # Errors
///
/// Returns the error from `collect_types` when `data` is a union.
pub fn add(
    input_ident: &Ident,
    generics: &mut Generics,
    data: &syn::Data,
    codec_bound: syn::Path,
    codec_skip_bound: Option<syn::Path>,
) -> syn::Result<()> {
    let ty_params = generics.type_params().map(|p| p.ident.clone()).collect::<Vec<_>>();
    if ty_params.is_empty() {
        return Ok(());
    }

    let codec_types = collect_types(data, needs_codec_bound, variant_not_skipped)?
        .into_iter()
        .filter(|ty| type_contain_idents(ty, &ty_params))
        .flat_map(|ty| {
            // Also bound the parts of `ty` that don't refer to `input_ident`, so that
            // types mentioning `input_ident` can be filtered out below without
            // losing the bounds on their type-parameter-dependent parts.
            find_type_paths_not_start_or_contain_ident(&ty, input_ident)
                .into_iter()
                // The paths are already owned; wrap them without cloning.
                .map(Type::Path)
                .filter(|ty| type_contain_idents(ty, &ty_params))
                .chain(iter::once(ty))
        })
        .filter(|ty| !type_or_sub_type_path_starts_with_ident(ty, input_ident))
        .collect::<Vec<_>>();

    let compact_types = collect_types(data, needs_has_compact_bound, variant_not_skipped)?
        .into_iter()
        .filter(|ty| type_contain_idents(ty, &ty_params))
        .collect::<Vec<_>>();

    // Skipped fields are only bounded when the caller asked for a skip bound.
    let skip_types = if codec_skip_bound.is_some() {
        collect_types(data, needs_default_bound, variant_not_skipped)?
            .into_iter()
            .filter(|ty| type_contain_idents(ty, &ty_params))
            .collect::<Vec<_>>()
    } else {
        Vec::new()
    };

    if codec_types.is_empty() && compact_types.is_empty() && skip_types.is_empty() {
        return Ok(());
    }

    let where_clause = generics.make_where_clause();

    for ty in codec_types {
        where_clause.predicates.push(parse_quote!(#ty : #codec_bound));
    }

    let has_compact_bound: syn::Path = parse_quote!(_parity_scale_codec::HasCompact);
    for ty in compact_types {
        where_clause.predicates.push(parse_quote!(#ty : #has_compact_bound));
    }

    // `skip_types` is non-empty only when `codec_skip_bound` is `Some`.
    if let Some(codec_skip_bound) = codec_skip_bound {
        for ty in skip_types {
            where_clause.predicates.push(parse_quote!(#ty : #codec_skip_bound));
        }
    }

    Ok(())
}
fn needs_codec_bound(field: &syn::Field) -> bool {
!crate::utils::get_enable_compact(field)
&& crate::utils::get_encoded_as_type(field).is_none()
&& crate::utils::get_skip(&field.attrs).is_none()
}
fn needs_has_compact_bound(field: &syn::Field) -> bool {
crate::utils::get_enable_compact(field)
}
/// Returns `true` if `field` carries a skip attribute and therefore needs the
/// caller-provided skip bound.
fn needs_default_bound(field: &syn::Field) -> bool {
    let skip_attr = crate::utils::get_skip(&field.attrs);
    skip_attr.is_some()
}
/// Returns `true` if `variant` carries no skip attribute.
fn variant_not_skipped(variant: &syn::Variant) -> bool {
    let skip_attr = crate::utils::get_skip(&variant.attrs);
    skip_attr.is_none()
}
fn collect_types(
data: &syn::Data,
type_filter: fn(&syn::Field) -> bool,
variant_filter: fn(&syn::Variant) -> bool,
) -> syn::Result<Vec<syn::Type>> {
use syn::*;
let types = match *data {
Data::Struct(ref data) => match &data.fields {
| Fields::Named(FieldsNamed { named: fields , .. })
| Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) => {
fields.iter()
.filter(|f| type_filter(f))
.map(|f| f.ty.clone())
.collect()
},
Fields::Unit => { Vec::new() },
},
Data::Enum(ref data) => data.variants.iter()
.filter(|variant| variant_filter(variant))
.flat_map(|variant| {
match &variant.fields {
| Fields::Named(FieldsNamed { named: fields , .. })
| Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) => {
fields.iter()
.filter(|f| type_filter(f))
.map(|f| f.ty.clone())
.collect()
},
Fields::Unit => { Vec::new() },
}
}).collect(),
Data::Union(ref data) => return Err(Error::new(
data.union_token.span(),
"Union types are not supported."
)),
};
Ok(types)
}